diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..53c101720834ca44149a76b8e32f111ecbe682c7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,20 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/audio/0.wav filter=lfs diff=lfs merge=lfs -text +assets/audio/1.wav filter=lfs diff=lfs merge=lfs -text +assets/demo.png filter=lfs diff=lfs merge=lfs -text +assets/depth/0.png filter=lfs diff=lfs merge=lfs -text +assets/depth/1.png filter=lfs diff=lfs merge=lfs -text +assets/emergency.jpg filter=lfs diff=lfs merge=lfs -text +assets/iclr_dataset_sample.jpg filter=lfs diff=lfs merge=lfs -text +assets/languagebind.jpg filter=lfs diff=lfs merge=lfs -text +assets/languagebind_frame.jpg filter=lfs diff=lfs merge=lfs -text +assets/languagebind_result.jpg filter=lfs diff=lfs merge=lfs -text +assets/languge_result.jpg filter=lfs diff=lfs merge=lfs -text +assets/logo.jpg filter=lfs diff=lfs merge=lfs -text +assets/logo_languagebind.png filter=lfs diff=lfs merge=lfs -text +assets/result1.jpg filter=lfs diff=lfs merge=lfs -text +assets/sota.jpg filter=lfs diff=lfs merge=lfs -text +assets/video/0.mp4 filter=lfs diff=lfs merge=lfs -text +assets/video/1.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/1/1 b/1/1 new file mode 100644 index 0000000000000000000000000000000000000000..d00491fd7e5bb6fa28c517a0bb32b8b506539d4d --- /dev/null +++ b/1/1 @@ -0,0 +1 @@ +1 diff --git a/DATASETS.md b/DATASETS.md new file mode 100644 index 0000000000000000000000000000000000000000..4c74c19fb68ac16eac91b0fc07c01762ddecf6f4 --- /dev/null +++ b/DATASETS.md @@ -0,0 +1,66 @@ +## Sample data +We are releasing sample data here so that individuals who are interested can further modify the code to train it on their own data, which includes videos, text from various sources, depth, and infrared. + +
+ + + + + + + + + + +
Baidu YunGoogle CloudPeking University Yun
DATALinkLinkLink
ANNOTATIONLinkLinkLink
+
+ +## VIDAL-10M + +### Text and Video +Due to policy restrictions, we are unable to directly release the videos. However, we provide the YouTube IDs, which can be used to download the videos independently. All textual sources and YouTube IDs can be downloaded from [Google Disk](https://drive.google.com/file/d/1qgm3rO9JugazLJ6KRsAKZfLIagHu3PJ-/view?usp=sharing) or [Baidu Disk](https://pan.baidu.com/s/13gY-IcFSFIuDZ-q0hMTx0g?pwd=gum9). + +The organization format of `ANNOTATION` is as follows: +```Bash +{ + "ImkVYKWqlDU": { + "folder": "coco_vat_9", + "mplug": "This video describes a group of scuba divers rolling backwards off a boat while playing an instrument. They are having fun and enjoying their time in the water.", + "polish_mplug": "scuba divers are seen rolling backwards off a boat while playing an instrument, displaying enjoyment and having a good time in the water.", + "ofa": [ + " a man in a wet suit and a helmet on a boat", + " a man in a scuba suit on a boat", + " a person in a boat holding a diver helmet", + " a man in a wetsuit on a jet ski", + " a picture of a body of water with the words boats on it", + " a person in the water with the words if they rolled", + " a person in the water with a paddle", + " a person in the water with a scooter" + ], + "sound_mplug": "scuba divers rolling backwards off a boat while playing an instrument showcases exuberant laughter, splashing water, and cheery melodies blending with the gentle waves.", + "raw": "WHY SCUBA DIVERS ROLL BACKWARDS OFF BOAT #shorts" + }, + "id": { + "folder": "video_folder", + "mplug": "mplug_caption", + "polish_mplug": "polish_mplug_caption", + "ofa": [ + "ofa_caption_0", + "ofa_caption_1", + "ofa_caption_2", + "ofa_caption_3", + "ofa_caption_4", + "ofa_caption_5", + "ofa_caption_6", + "ofa_caption_7" + ], + "sound_mplug": "sound_mplug_caption", + "raw": "raw_caption#hashtags" + }, + ... +} +``` + +### Depth and Thermal (Infrared) + +We are uploading data to [Hugging Face](https://huggingface.co/datasets/LanguageBind/VIDAL-Depth-Thermal), but based on a conservative estimate, it's approximately **20T**. Please be patient as we work on it. diff --git a/DATASET_LICENSE b/DATASET_LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a115f899f8d09ef3b1def4a16c7bae1a0bd50fbe --- /dev/null +++ b/DATASET_LICENSE @@ -0,0 +1,400 @@ + +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..13db287e16159d9c1b059aefaccd22b3d5fea059 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 PKU-YUAN's Group (袁粒课题组-北大信工) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f86c850c7f6b1ad2997f7e142150899d9b0c79c7 --- /dev/null +++ b/README.md @@ -0,0 +1,422 @@ +

+ +

+

【ICLR 2024 🔥】LanguageBind: Extending Video-Language Pretraining to N-modality by Language-based Semantic Alignment

+
If you like our project, please give us a star ⭐ on GitHub for latest update.
+ + +
+ +[![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/LanguageBind) +[![Dataset meta](https://img.shields.io/badge/%F0%9F%A4%97%20Dataset-VIDAL-blue)](https://huggingface.co/datasets/LanguageBind/VIDAL-Depth-Thermal) +[![arXiv](https://img.shields.io/badge/Arxiv-2310.01852-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2310.01852) +[![wechat](https://img.shields.io/badge/量子位%20-black)](https://mp.weixin.qq.com/s/EFqLv_Euf5VU024zOtzkkg) +[![jiqizhixin](https://img.shields.io/badge/机器之心%20-black)](https://mp.weixin.qq.com/s/E5Tazm_vz1CADMwV0tdhnw) +[![zhihu](https://img.shields.io/badge/知乎-0084FF)](https://zhuanlan.zhihu.com/p/660567767) +[![License](https://img.shields.io/badge/Code%20License-MIT-yellow)](https://github.com/PKU-YuanGroup/LanguageBind/blob/main/LICENSE) +[![Data License](https://img.shields.io/badge/Dataset%20license-CC--BY--NC%204.0-orange)](https://github.com/PKU-YuanGroup/LanguageBind/blob/main/DATASET_LICENSE) +[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FPKU-YuanGroup%2FLanguageBind&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visitor&edge_flat=false)](https://hits.seeyoufarm.com) +[![GitHub issues](https://img.shields.io/github/issues/PKU-YuanGroup/LanguageBind?color=critical&label=Issues)](https://github.com/PKU-YuanGroup/LanguageBind/issues?q=is%3Aopen+is%3Aissue) +[![GitHub closed issues](https://img.shields.io/github/issues-closed/PKU-YuanGroup/LanguageBind?color=success&label=Issues)](https://github.com/PKU-YuanGroup/LanguageBind/issues?q=is%3Aissue+is%3Aclosed)
+ +
+ +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-audio-classification-on-audioset)](https://paperswithcode.com/sota/zero-shot-audio-classification-on-audioset?p=languagebind-extending-video-language)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-audio-classification-on-vgg-sound)](https://paperswithcode.com/sota/zero-shot-audio-classification-on-vgg-sound?p=languagebind-extending-video-language)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-text-to-audio-retrieval-on-clotho)](https://paperswithcode.com/sota/zero-shot-text-to-audio-retrieval-on-clotho?p=languagebind-extending-video-language)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-scene-classification-unified)](https://paperswithcode.com/sota/zero-shot-scene-classification-unified?p=languagebind-extending-video-language)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-classification-unified-classes-on)](https://paperswithcode.com/sota/zero-shot-classification-unified-classes-on?p=languagebind-extending-video-language)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-video-retrieval-on-msvd)](https://paperswithcode.com/sota/zero-shot-video-retrieval-on-msvd?p=languagebind-extending-video-language)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-environment-sound-classification-on-1)](https://paperswithcode.com/sota/zero-shot-environment-sound-classification-on-1?p=languagebind-extending-video-language)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-text-to-audio-retrieval-on)](https://paperswithcode.com/sota/zero-shot-text-to-audio-retrieval-on?p=languagebind-extending-video-language)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-video-retrieval-on-activitynet)](https://paperswithcode.com/sota/zero-shot-video-retrieval-on-activitynet?p=languagebind-extending-video-language)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-video-retrieval-on-msr-vtt)](https://paperswithcode.com/sota/zero-shot-video-retrieval-on-msr-vtt?p=languagebind-extending-video-language)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-video-retrieval-on-didemo)](https://paperswithcode.com/sota/zero-shot-video-retrieval-on-didemo?p=languagebind-extending-video-language)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/languagebind-extending-video-language/zero-shot-action-recognition-on-kinetics)](https://paperswithcode.com/sota/zero-shot-action-recognition-on-kinetics?p=languagebind-extending-video-language) + +
💡 I also have other vision-language projects that may interest you ✨.

+ + +> [**Video-LLaVA: Learning United Visual Representation by Alignment Before Projection**](https://arxiv.org/abs/2311.10122)
+> Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munan Ning, Peng Jin, Li Yuan
+[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/PKU-YuanGroup/Video-LLaVA) [![github](https://img.shields.io/github/stars/PKU-YuanGroup/Video-LLaVA.svg?style=social)](https://github.com/PKU-YuanGroup/Video-LLaVA) [![arXiv](https://img.shields.io/badge/Arxiv-2311.10122-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2311.10122)
+ +> [**MoE-LLaVA: Mixture of Experts for Large Vision-Language Models**](https://github.com/PKU-YuanGroup/MoE-LLaVA/blob/main/MoE-LLaVA.pdf)
+> Bin Lin, Zhenyu Tang, Yang Ye, Jiaxi Cui, Bin Zhu, Peng Jin, Junwu Zhang, Munan Ning, Li Yuan
+[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/PKU-YuanGroup/MoE-LLaVA) [![github](https://img.shields.io/github/stars/PKU-YuanGroup/MoE-LLaVA.svg?style=social)](https://github.com/PKU-YuanGroup/MoE-LLaVA) [![arXiv](https://img.shields.io/badge/Arxiv-2401.15947-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2401.15947)
+ +> [**Video-Bench: A Comprehensive Benchmark and Toolkit for Evaluating Video-based Large Language Models**](https://arxiv.org/abs/2311.08046)
+> Munan Ning, Bin Zhu, Yujia Xie, Bin Lin, Jiaxi Cui, Lu Yuan, Dongdong Chen, Li Yuan
+[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/PKU-YuanGroup/Video-Bench) [![github](https://img.shields.io/github/stars/PKU-YuanGroup/Video-Bench.svg?style=social)](https://github.com/PKU-YuanGroup/Video-Bench) [![arXiv](https://img.shields.io/badge/Arxiv-2311.16103-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2311.16103)
+ + +

+ +## 📰 News +* **[2024.01.27]** 👀👀👀 Our [MoE-LLaVA](https://github.com/PKU-YuanGroup/MoE-LLaVA) is released! A sparse model with 3B parameters outperformed the dense model with 7B parameters. +* **[2024.01.16]** 🔥🔥🔥 Our LanguageBind has been accepted at ICLR 2024! We earn the score of 6(3)8(6)6(6)6(6) [here](https://openreview.net/forum?id=QmZKc7UZCy¬eId=OgsxQxAleA). +* **[2023.12.15]** 💪💪💪 We expand the 💥💥💥 VIDAL dataset and now have **10M video-text data**. We launch **LanguageBind_Video 1.5**, checking our [model zoo](#-model-zoo). +* **[2023.12.10]** We expand the 💥💥💥 VIDAL dataset and now have **10M depth and 10M thermal data**. We are in the process of uploading thermal and depth data on [Hugging Face](https://huggingface.co/datasets/LanguageBind/VIDAL-Depth-Thermal) and expect the whole process to last 1-2 months. +* **[2023.11.27]** 🔥🔥🔥 We have updated our [paper](https://arxiv.org/abs/2310.01852) with emergency zero-shot results., checking our ✨ [results](#emergency-results). +* **[2023.11.26]** 💥💥💥 We have open-sourced all textual sources and corresponding YouTube IDs [here](DATASETS.md). +* **[2023.11.26]** 📣📣📣 We have open-sourced fully fine-tuned **Video & Audio**, achieving improved performance once again, checking our [model zoo](#-model-zoo). +* **[2023.11.22]** We are about to release a fully fine-tuned version, and the **HUGE** version is currently undergoing training. +* **[2023.11.21]** 💥 We are releasing sample data in [DATASETS.md](DATASETS.md) so that individuals who are interested can further modify the code to train it on their own data. +* **[2023.11.20]** 🚀🚀🚀 [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA) builds a large visual-language model to achieve 🎉SOTA performances based on LanguageBind encoders. +* **[2023.10.23]** 🎶 LanguageBind-Audio achieves 🎉🎉🎉**state-of-the-art (SOTA) performance on 5 datasets**, checking our ✨ [results](#multiple-modalities)! +* **[2023.10.14]** 😱 Released a stronger LanguageBind-Video, checking our ✨ [results](#video-language)! The video checkpoint **have updated** on Huggingface Model Hub! +* **[2023.10.10]** We provide sample data, which can be found in [assets](assets), and [emergency zero-shot usage](#emergency-zero-shot) is described. +* **[2023.10.07]** The checkpoints are available on 🤗 [Huggingface Model](https://huggingface.co/LanguageBind). +* **[2023.10.04]** Code and [demo](https://huggingface.co/spaces/LanguageBind/LanguageBind) are available now! Welcome to **watch** 👀 this repository for the latest updates. + +## 😮 Highlights + +### 💡 High performance, but NO intermediate modality required +LanguageBind is a **language-centric** multimodal pretraining approach, **taking the language as the bind across different modalities** because the language modality is well-explored and contains rich semantics. +* The following first figure shows the architecture of LanguageBind. LanguageBind can be easily extended to segmentation, detection tasks, and potentially to unlimited modalities. + +### ⚡️ A multimodal, fully aligned and voluminous dataset +We propose **VIDAL-10M**, **10 Million data** with **V**ideo, **I**nfrared, **D**epth, **A**udio and their corresponding **L**anguage, which greatly expands the data beyond visual modalities. +* The second figure shows our proposed VIDAL-10M dataset, which includes five modalities: video, infrared, depth, audio, and language. + +### 🔥 Multi-view enhanced description for training +We make multi-view enhancements to language. We produce multi-view description that combines **meta-data**, **spatial**, and **temporal** to greatly enhance the semantic information of the language. In addition we further **enhance the language with ChatGPT** to create a good semantic space for each modality aligned language. + +

+ +

+

+ +

+ +## 🤗 Demo + +* **Local demo.** Highly recommend trying out our web demo, which incorporates all features currently supported by LanguageBind. +```bash +python gradio_app.py +``` + +* **Online demo.** We provide the [online demo](https://huggingface.co/spaces/LanguageBind/LanguageBind) in Huggingface Spaces. In this demo, you can calculate the similarity of modalities to language, such as audio-to-language, video-to-language, and depth-to-image. +

+ +

+ + + +## 🚀 Main Results + +### Video-Language +LanguageBind achieves **state-of-the-art (SOTA) performance on four datasets**, * donates the results of full tuning. +

+ +

+ +### Multiple Modalities +Video-Language, Infrared-Language, Depth-Language, and Audio-Language zero-shot classification, * donates the results of full tuning. +

+ +

+We report text-to-audio results for retrieval, * donates the results of full tuning. +

+ +

+ +### Emergency results +

+ +

+ +## 🛠️ Requirements and Installation +* Python >= 3.8 +* Pytorch >= 1.13.1 +* CUDA Version >= 11.6 +* Install required packages: +```bash +git clone https://github.com/PKU-YuanGroup/LanguageBind +cd LanguageBind +pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 +pip install -r requirements.txt +``` + +## 🐳 Model Zoo + +The names in the table represent different encoder models. For example, `LanguageBind/LanguageBind_Video_FT` represents the fully fine-tuned version, while `LanguageBind/LanguageBind_Video` represents the LoRA-tuned version. + +You can freely replace them in the recommended [API usage](#-api). We recommend using the fully fine-tuned version, as it offers stronger performance. + +
+ + + + + + + + + + + + + + + + +
ModalityLoRA tuningFine-tuning
VideoLanguageBind_VideoLanguageBind_Video_FT
AudioLanguageBind_AudioLanguageBind_Audio_FT
DepthLanguageBind_Depth-
ThermalLanguageBind_Thermal-
+
+ + +
+ + + + + + + + + + + + + + + + + + + + + + +
VersionTuningModel sizeNum_framesHF LinkMSR-VTTDiDeMoActivityNetMSVD
LanguageBind_VideoLoRALarge8Link42.637.835.152.2
LanguageBind_Video_FTFull-tuningLarge8Link42.738.136.953.5
LanguageBind_Video_V1.5_FTFull-tuningLarge8Link42.839.738.454.1
LanguageBind_Video_V1.5_FTFull-tuningLarge12Coming soon
LanguageBind_Video_Huge_V1.5_FTFull-tuningHuge8Link44.839.941.053.7
LanguageBind_Video_Huge_V1.5_FTFull-tuningHuge12Coming soon
+
+ +## 🤖 API +**We open source all modalities preprocessing code.** If you want to load the model (e.g. ```LanguageBind/LanguageBind_Thermal```) from the model hub on Huggingface or on local, you can use the following code snippets! +### Inference for Multi-modal Binding +We have provided some sample datasets in [assets](assets) to quickly see how languagebind works. +```python +import torch +from languagebind import LanguageBind, to_device, transform_dict, LanguageBindImageTokenizer + +if __name__ == '__main__': + device = 'cuda:0' + device = torch.device(device) + clip_type = { + 'video': 'LanguageBind_Video_FT', # also LanguageBind_Video + 'audio': 'LanguageBind_Audio_FT', # also LanguageBind_Audio + 'thermal': 'LanguageBind_Thermal', + 'image': 'LanguageBind_Image', + 'depth': 'LanguageBind_Depth', + } + + model = LanguageBind(clip_type=clip_type, cache_dir='./cache_dir') + model = model.to(device) + model.eval() + pretrained_ckpt = f'lb203/LanguageBind_Image' + tokenizer = LanguageBindImageTokenizer.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir/tokenizer_cache_dir') + modality_transform = {c: transform_dict[c](model.modality_config[c]) for c in clip_type.keys()} + + image = ['assets/image/0.jpg', 'assets/image/1.jpg'] + audio = ['assets/audio/0.wav', 'assets/audio/1.wav'] + video = ['assets/video/0.mp4', 'assets/video/1.mp4'] + depth = ['assets/depth/0.png', 'assets/depth/1.png'] + thermal = ['assets/thermal/0.jpg', 'assets/thermal/1.jpg'] + language = ["Training a parakeet to climb up a ladder.", 'A lion climbing a tree to catch a monkey.'] + + inputs = { + 'image': to_device(modality_transform['image'](image), device), + 'video': to_device(modality_transform['video'](video), device), + 'audio': to_device(modality_transform['audio'](audio), device), + 'depth': to_device(modality_transform['depth'](depth), device), + 'thermal': to_device(modality_transform['thermal'](thermal), device), + } + inputs['language'] = to_device(tokenizer(language, max_length=77, padding='max_length', + truncation=True, return_tensors='pt'), device) + + with torch.no_grad(): + embeddings = model(inputs) + + print("Video x Text: \n", + torch.softmax(embeddings['video'] @ embeddings['language'].T, dim=-1).detach().cpu().numpy()) + print("Image x Text: \n", + torch.softmax(embeddings['image'] @ embeddings['language'].T, dim=-1).detach().cpu().numpy()) + print("Depth x Text: \n", + torch.softmax(embeddings['depth'] @ embeddings['language'].T, dim=-1).detach().cpu().numpy()) + print("Audio x Text: \n", + torch.softmax(embeddings['audio'] @ embeddings['language'].T, dim=-1).detach().cpu().numpy()) + print("Thermal x Text: \n", + torch.softmax(embeddings['thermal'] @ embeddings['language'].T, dim=-1).detach().cpu().numpy()) +``` +Then returns the following result. +```bash +Video x Text: + [[9.9989331e-01 1.0667283e-04] + [1.3255903e-03 9.9867439e-01]] +Image x Text: + [[9.9990666e-01 9.3292067e-05] + [4.6132666e-08 1.0000000e+00]] +Depth x Text: + [[0.9954276 0.00457235] + [0.12042473 0.8795753 ]] +Audio x Text: + [[0.97634876 0.02365119] + [0.02917843 0.97082156]] +Thermal x Text: + [[0.9482511 0.0517489 ] + [0.48746133 0.5125386 ]] +``` +### Emergency zero-shot +Since languagebind binds each modality together, we also found the **emergency zero-shot**. It's very simple to use. +```python +print("Video x Audio: \n", torch.softmax(embeddings['video'] @ embeddings['audio'].T, dim=-1).detach().cpu().numpy()) +print("Image x Depth: \n", torch.softmax(embeddings['image'] @ embeddings['depth'].T, dim=-1).detach().cpu().numpy()) +print("Image x Thermal: \n", torch.softmax(embeddings['image'] @ embeddings['thermal'].T, dim=-1).detach().cpu().numpy()) +``` +Then, you will get: +``` +Video x Audio: + [[1.0000000e+00 0.0000000e+00] + [3.1150486e-32 1.0000000e+00]] +Image x Depth: + [[1. 0.] + [0. 1.]] +Image x Thermal: + [[1. 0.] + [0. 1.]] + ``` + +### Different branches for X-Language task +Additionally, LanguageBind can be **disassembled into different branches** to handle different tasks. Note that we do not train Image, which just initialize from OpenCLIP. +#### Thermal +```python +import torch +from languagebind import LanguageBindThermal, LanguageBindThermalTokenizer, LanguageBindThermalProcessor + +pretrained_ckpt = 'LanguageBind/LanguageBind_Thermal' +model = LanguageBindThermal.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir') +tokenizer = LanguageBindThermalTokenizer.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir') +thermal_process = LanguageBindThermalProcessor(model.config, tokenizer) + +model.eval() +data = thermal_process([r"your/thermal.jpg"], ['your text'], return_tensors='pt') +with torch.no_grad(): + out = model(**data) + +print(out.text_embeds @ out.image_embeds.T) +``` + +#### Depth +```python +import torch +from languagebind import LanguageBindDepth, LanguageBindDepthTokenizer, LanguageBindDepthProcessor + +pretrained_ckpt = 'LanguageBind/LanguageBind_Depth' +model = LanguageBindDepth.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir') +tokenizer = LanguageBindDepthTokenizer.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir') +depth_process = LanguageBindDepthProcessor(model.config, tokenizer) + +model.eval() +data = depth_process([r"your/depth.png"], ['your text.'], return_tensors='pt') +with torch.no_grad(): + out = model(**data) + +print(out.text_embeds @ out.image_embeds.T) +``` + +#### Video +```python +import torch +from languagebind import LanguageBindVideo, LanguageBindVideoTokenizer, LanguageBindVideoProcessor + +pretrained_ckpt = 'LanguageBind/LanguageBind_Video_FT' # also 'LanguageBind/LanguageBind_Video' +model = LanguageBindVideo.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir') +tokenizer = LanguageBindVideoTokenizer.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir') +video_process = LanguageBindVideoProcessor(model.config, tokenizer) + +model.eval() +data = video_process(["your/video.mp4"], ['your text.'], return_tensors='pt') +with torch.no_grad(): + out = model(**data) + +print(out.text_embeds @ out.image_embeds.T) +``` + +#### Audio +```python +import torch +from languagebind import LanguageBindAudio, LanguageBindAudioTokenizer, LanguageBindAudioProcessor + +pretrained_ckpt = 'LanguageBind/LanguageBind_Audio_FT' # also 'LanguageBind/LanguageBind_Audio' +model = LanguageBindAudio.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir') +tokenizer = LanguageBindAudioTokenizer.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir') +audio_process = LanguageBindAudioProcessor(model.config, tokenizer) + +model.eval() +data = audio_process([r"your/audio.wav"], ['your audio.'], return_tensors='pt') +with torch.no_grad(): + out = model(**data) + +print(out.text_embeds @ out.image_embeds.T) +``` + +#### Image +Note that our image encoder is the same as OpenCLIP. **Not** as fine-tuned as other modalities. +```python +import torch +from languagebind import LanguageBindImage, LanguageBindImageTokenizer, LanguageBindImageProcessor + +pretrained_ckpt = 'LanguageBind/LanguageBind_Image' +model = LanguageBindImage.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir') +tokenizer = LanguageBindImageTokenizer.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir') +image_process = LanguageBindImageProcessor(model.config, tokenizer) + +model.eval() +data = image_process([r"your/image.jpg"], ['your text.'], return_tensors='pt') +with torch.no_grad(): + out = model(**data) + +print(out.text_embeds @ out.image_embeds.T) +``` + +## 💥 VIDAL-10M +The datasets is in [DATASETS.md](DATASETS.md). + +## 🗝️ Training & Validating +The training & validating instruction is in [TRAIN_AND_VALIDATE.md](TRAIN_AND_VALIDATE.md). + +## 👍 Acknowledgement +* [OpenCLIP](https://github.com/mlfoundations/open_clip) An open source pretraining framework. +* [CLIP4Clip](https://github.com/ArrowLuo/CLIP4Clip) An open source Video-Text retrieval framework. +* [sRGB-TIR](https://github.com/rpmsnu/sRGB-TIR) An open source framework to generate infrared (thermal) images. +* [GLPN](https://github.com/vinvino02/GLPDepth) An open source framework to generate depth images. + +## 🔒 License +* The majority of this project is released under the MIT license as found in the [LICENSE](https://github.com/PKU-YuanGroup/LanguageBind/blob/main/LICENSE) file. +* The dataset of this project is released under the CC-BY-NC 4.0 license as found in the [DATASET_LICENSE](https://github.com/PKU-YuanGroup/LanguageBind/blob/main/DATASET_LICENSE) file. + +## ✏️ Citation +If you find our paper and code useful in your research, please consider giving a star :star: and citation :pencil:. + +```BibTeX +@misc{zhu2023languagebind, + title={LanguageBind: Extending Video-Language Pretraining to N-modality by Language-based Semantic Alignment}, + author={Bin Zhu and Bin Lin and Munan Ning and Yang Yan and Jiaxi Cui and Wang HongFa and Yatian Pang and Wenhao Jiang and Junwu Zhang and Zongwei Li and Cai Wan Zhang and Zhifeng Li and Wei Liu and Li Yuan}, + year={2023}, + eprint={2310.01852}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + + +## ✨ Star History + +[![Star History](https://api.star-history.com/svg?repos=PKU-YuanGroup/LanguageBind&type=Date)](https://star-history.com/#PKU-YuanGroup/LanguageBind&Date) + + +## 🤝 Contributors + + + + diff --git a/TRAIN_AND_VALIDATE.md b/TRAIN_AND_VALIDATE.md new file mode 100644 index 0000000000000000000000000000000000000000..68d2dfac7803a2fb880e838457d133d3c454b11d --- /dev/null +++ b/TRAIN_AND_VALIDATE.md @@ -0,0 +1,214 @@ +We provide the **off-the-shelf** scripts in the [scripts folder](scripts). + +## Training LanguageBind + + +
+ + + + + + + + + + +
Cache of pretrained weightBaidu YunGoogle CloudPeking University Yun
LargeLinkLinkLink
HugeLink-Link
+
+ + +For example, to **train** LanguageBind on **Depth-Language** with 8 GPUs (1 nodes x 8 GPUs). +* First download the cache of pretrained weight above. and specify `CACHE_DIR=path/to/LanguageBind`. +* The second step is to develop a path to `ANNOTATION` and `DATA` [here](https://github.com/PKU-YuanGroup/LanguageBind/blob/main/data/base_datasets.py#L37) according to the [dataset preparation](https://github.com/PKU-YuanGroup/LanguageBind#-vidal-10m). +* Then you can run + +```bash +CACHE_DIR="/path/to/LanguageBind" +ANNOTATION="path/to/data" +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nnodes=1 --nproc_per_node 8 \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 3020000 \ + --clip-type "dl" --max-depth 10 \ + --do_train \ + --lock-text --lock-image --text-type "polish_mplug" \ + --init-temp 0.07 --learn-temp \ + --model "ViT-L-14" --cache-dir ${CACHE_DIR} \ + --convert_to_lora --lora_r 2 \ + --lr 5e-4 --coef-lr 1e-3 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 1 --force-patch-dropout 0.5 \ + --epochs 1 --batch-size 128 --accum-freq 1 --warmup 200 \ + --precision "amp" --workers 10 --video-decode-backend "imgs" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume "latest" \ + --do_eval \ + --val_d_cls_data "NYUV2" +``` + + +## Validating LanguageBind + +For example, to **validate** LanguageBind on **Depth-Language** with 1 GPUs. +* First specify ```RESUME```. +* The second step is to prepare the [downstream dataset](https://github.com/PKU-YuanGroup/LanguageBind/blob/main/TRAIN_AND_VALIDATE.md#downstream-datasets). +* Then you can run + +```bash +CACHE_DIR="/path/to/LanguageBind" +RESUME="thermal_language.pt" +ANNOTATION="path/to/data" +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nproc_per_node 1 \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 3020000 \ + --clip-type "dl" --max-depth 10 \ + --lock-text --lock-image --text-type "polish_mplug" \ + --init-temp 0.07 --learn-temp \ + --model "ViT-L-14" --cache-dir ${CACHE_DIR} \ + --convert_to_lora --lora_r 2 \ + --lr 5e-4 --coef-lr 1e-3 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 1 --force-patch-dropout 0.5 \ + --epochs 1 --batch-size 128 --accum-freq 1 --warmup 200 \ + --precision "amp" --workers 10 --video-decode-backend "imgs" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume ${RESUME} \ + --do_eval \ + --val_d_cls_data "NYUV2" +``` + +## Downstream datasets + +### Depth +NYU V2 dataset is downloaded from [this repo](https://github.com/TUI-NICR/nicr-scene-analysis-datasets/tree/main/nicr_scene_analysis_datasets/datasets/nyuv2) and we reformat them to conform to the standard ImageNet format. We also provide data as follows. Change the ```data_root``` [here](https://github.com/PKU-YuanGroup/LanguageBind/blob/main/data/build_datasets.py#L221). + +
+ + + + + + + +
DatasetsBaidu YunGoogle CloudPeking University Yun
NYULinkLinkLink
+
+ +### Video +Video datasets are downloaded from [this repo](https://github.com/jpthu17/HBI) and we show the folder structure. Change the ```data_root``` [here](https://github.com/PKU-YuanGroup/LanguageBind/blob/main/data/build_datasets.py#L74). + +### Audio +Audio datasets are downloaded from [this repo](https://github.com/OFA-Sys/ONE-PEACE/blob/main/datasets.md#audio) and Audioset from [here](https://github.com/qiuqiangkong/audioset_tagging_cnn#1-download-dataset).We reformat them to conform to the standard ImageNet format. Change the ```data_root``` [here1](https://github.com/PKU-YuanGroup/LanguageBind/blob/main/data/build_datasets.py#L144) and [here2](https://github.com/PKU-YuanGroup/LanguageBind/blob/main/data/build_datasets.py#L159). + +### Infrared (Thermal) +We download LLVIP from [official website](https://bupt-ai-cz.github.io/LLVIP/), and FLIR from [here](https://www.flir.com/oem/adas/adas-dataset-form/). We reformat them to conform to the standard ImageNet format. Change the ```data_root``` [here](https://github.com/PKU-YuanGroup/LanguageBind/blob/main/data/build_datasets.py#L233). We also provide the processed data as follows. + +
+ + + + + + + + + + + + + +
DatasetsBaidu YunGoogle CloudPeking University Yun
LLVIPLinkLinkLink
FLIR V1LinkLinkLink
FLIR V2LinkLinkLink
+
+ +### Folder structure +```bash +downstream_datasets +├── Audio +│   ├── audiocaps +│   │ └── audio +│   │ ├── test +│   │ ├── train +│   │ └── val +│ ├── audioset +│   │ ├── balanced_train_segments +│   │ ├── eval_segments +│   │ └── unbalanced_train_segments +│   │ ├── unbalanced_train_segments_part00 +│   │ ├── unbalanced_train_segments_part01 +│   │ ├── ... +│   │ └── unbalanced_train_segments_part40 +│ ├── clotho +│   │ ├── CLOTHO_retrieval_dataset +│   │ └── evaluation +│ ├── esc50 +│   │ └── test +│   │ ├── airplane +│   │ ├── breathing +│   │ ├── ... +│   │ └── wind +├── laionaudio +│   │ ├── audios +│   │ ├── freesound_no_overlap +│   │ └── jsons +├── vggsound +│ └── test +│ ├── air\ conditioning\ noise +│ ├── air\ horn +│ ├── ... +│ └── zebra\ braying +├── Depth +│   ├── nyuv2 +│   │   ├── data +│   │   │   └── val +│   │   │   ├── bathroom +│   │   │   ├── bedroom +│   │   │   ├── bookstore +│   │   │   ├── classroom +│   │   │   ├── dining_room +│   │   │   ├── home_office +│   │   │   ├── kitchen +│   │   │   ├── living_room +│   │   │   ├── office +│   │   │   └── others +├── Thermal +│   ├── flirv1 +│   │   └── val +│   │   ├── bicycle +│   │   ├── car +│   │   ├── dog +│   │   └── person +│   ├── flirv2 +│   │   └── val +│   │   ├── bike +│   │   ├── bus +│   │   ├── car +│   │   ├── hydrant +│   │   ├── light +│   │   ├── motor +│   │   ├── other\ vehicle +│   │   ├── person +│   │   ├── sign +│   │   ├── skateboard +│   │   ├── stroller +│   │   └── truck +│   ├── llvip +│   │   ├── train +│   │   │   ├── background +│   │   │   └── person +│   │   └── val +│   │   ├── background +│   │   └── person +└── VideoTextRetrieval + ├── vtRetdata + │   ├── ActivityNet + │   │   └── Videos + │   │   └── Activity_Videos + │   ├── Didemo + │   │   └── videos + │   ├── MSRVTT + │   │   └── MSRVTT_Videos + │   └── MSVD + │   └── MSVD_Videos +``` + diff --git a/a_cls/class_labels_indices.csv b/a_cls/class_labels_indices.csv new file mode 100644 index 0000000000000000000000000000000000000000..3a2767e81114adecde59992cf6607f31c1862f4c --- /dev/null +++ b/a_cls/class_labels_indices.csv @@ -0,0 +1,528 @@ +index,mid,display_name +0,/m/09x0r,"Speech" +1,/m/05zppz,"Male speech, man speaking" +2,/m/02zsn,"Female speech, woman speaking" +3,/m/0ytgt,"Child speech, kid speaking" +4,/m/01h8n0,"Conversation" +5,/m/02qldy,"Narration, monologue" +6,/m/0261r1,"Babbling" +7,/m/0brhx,"Speech synthesizer" +8,/m/07p6fty,"Shout" +9,/m/07q4ntr,"Bellow" +10,/m/07rwj3x,"Whoop" +11,/m/07sr1lc,"Yell" +12,/m/04gy_2,"Battle cry" +13,/t/dd00135,"Children shouting" +14,/m/03qc9zr,"Screaming" +15,/m/02rtxlg,"Whispering" +16,/m/01j3sz,"Laughter" +17,/t/dd00001,"Baby laughter" +18,/m/07r660_,"Giggle" +19,/m/07s04w4,"Snicker" +20,/m/07sq110,"Belly laugh" +21,/m/07rgt08,"Chuckle, chortle" +22,/m/0463cq4,"Crying, sobbing" +23,/t/dd00002,"Baby cry, infant cry" +24,/m/07qz6j3,"Whimper" +25,/m/07qw_06,"Wail, moan" +26,/m/07plz5l,"Sigh" +27,/m/015lz1,"Singing" +28,/m/0l14jd,"Choir" +29,/m/01swy6,"Yodeling" +30,/m/02bk07,"Chant" +31,/m/01c194,"Mantra" +32,/t/dd00003,"Male singing" +33,/t/dd00004,"Female singing" +34,/t/dd00005,"Child singing" +35,/t/dd00006,"Synthetic singing" +36,/m/06bxc,"Rapping" +37,/m/02fxyj,"Humming" +38,/m/07s2xch,"Groan" +39,/m/07r4k75,"Grunt" +40,/m/01w250,"Whistling" +41,/m/0lyf6,"Breathing" +42,/m/07mzm6,"Wheeze" +43,/m/01d3sd,"Snoring" +44,/m/07s0dtb,"Gasp" +45,/m/07pyy8b,"Pant" +46,/m/07q0yl5,"Snort" +47,/m/01b_21,"Cough" +48,/m/0dl9sf8,"Throat clearing" +49,/m/01hsr_,"Sneeze" +50,/m/07ppn3j,"Sniff" +51,/m/06h7j,"Run" +52,/m/07qv_x_,"Shuffle" +53,/m/07pbtc8,"Walk, footsteps" +54,/m/03cczk,"Chewing, mastication" +55,/m/07pdhp0,"Biting" +56,/m/0939n_,"Gargling" +57,/m/01g90h,"Stomach rumble" +58,/m/03q5_w,"Burping, eructation" +59,/m/02p3nc,"Hiccup" +60,/m/02_nn,"Fart" +61,/m/0k65p,"Hands" +62,/m/025_jnm,"Finger snapping" +63,/m/0l15bq,"Clapping" +64,/m/01jg02,"Heart sounds, heartbeat" +65,/m/01jg1z,"Heart murmur" +66,/m/053hz1,"Cheering" +67,/m/028ght,"Applause" +68,/m/07rkbfh,"Chatter" +69,/m/03qtwd,"Crowd" +70,/m/07qfr4h,"Hubbub, speech noise, speech babble" +71,/t/dd00013,"Children playing" +72,/m/0jbk,"Animal" +73,/m/068hy,"Domestic animals, pets" +74,/m/0bt9lr,"Dog" +75,/m/05tny_,"Bark" +76,/m/07r_k2n,"Yip" +77,/m/07qf0zm,"Howl" +78,/m/07rc7d9,"Bow-wow" +79,/m/0ghcn6,"Growling" +80,/t/dd00136,"Whimper (dog)" +81,/m/01yrx,"Cat" +82,/m/02yds9,"Purr" +83,/m/07qrkrw,"Meow" +84,/m/07rjwbb,"Hiss" +85,/m/07r81j2,"Caterwaul" +86,/m/0ch8v,"Livestock, farm animals, working animals" +87,/m/03k3r,"Horse" +88,/m/07rv9rh,"Clip-clop" +89,/m/07q5rw0,"Neigh, whinny" +90,/m/01xq0k1,"Cattle, bovinae" +91,/m/07rpkh9,"Moo" +92,/m/0239kh,"Cowbell" +93,/m/068zj,"Pig" +94,/t/dd00018,"Oink" +95,/m/03fwl,"Goat" +96,/m/07q0h5t,"Bleat" +97,/m/07bgp,"Sheep" +98,/m/025rv6n,"Fowl" +99,/m/09b5t,"Chicken, rooster" +100,/m/07st89h,"Cluck" +101,/m/07qn5dc,"Crowing, cock-a-doodle-doo" +102,/m/01rd7k,"Turkey" +103,/m/07svc2k,"Gobble" +104,/m/09ddx,"Duck" +105,/m/07qdb04,"Quack" +106,/m/0dbvp,"Goose" +107,/m/07qwf61,"Honk" +108,/m/01280g,"Wild animals" +109,/m/0cdnk,"Roaring cats (lions, tigers)" +110,/m/04cvmfc,"Roar" +111,/m/015p6,"Bird" +112,/m/020bb7,"Bird vocalization, bird call, bird song" +113,/m/07pggtn,"Chirp, tweet" +114,/m/07sx8x_,"Squawk" +115,/m/0h0rv,"Pigeon, dove" +116,/m/07r_25d,"Coo" +117,/m/04s8yn,"Crow" +118,/m/07r5c2p,"Caw" +119,/m/09d5_,"Owl" +120,/m/07r_80w,"Hoot" +121,/m/05_wcq,"Bird flight, flapping wings" +122,/m/01z5f,"Canidae, dogs, wolves" +123,/m/06hps,"Rodents, rats, mice" +124,/m/04rmv,"Mouse" +125,/m/07r4gkf,"Patter" +126,/m/03vt0,"Insect" +127,/m/09xqv,"Cricket" +128,/m/09f96,"Mosquito" +129,/m/0h2mp,"Fly, housefly" +130,/m/07pjwq1,"Buzz" +131,/m/01h3n,"Bee, wasp, etc." +132,/m/09ld4,"Frog" +133,/m/07st88b,"Croak" +134,/m/078jl,"Snake" +135,/m/07qn4z3,"Rattle" +136,/m/032n05,"Whale vocalization" +137,/m/04rlf,"Music" +138,/m/04szw,"Musical instrument" +139,/m/0fx80y,"Plucked string instrument" +140,/m/0342h,"Guitar" +141,/m/02sgy,"Electric guitar" +142,/m/018vs,"Bass guitar" +143,/m/042v_gx,"Acoustic guitar" +144,/m/06w87,"Steel guitar, slide guitar" +145,/m/01glhc,"Tapping (guitar technique)" +146,/m/07s0s5r,"Strum" +147,/m/018j2,"Banjo" +148,/m/0jtg0,"Sitar" +149,/m/04rzd,"Mandolin" +150,/m/01bns_,"Zither" +151,/m/07xzm,"Ukulele" +152,/m/05148p4,"Keyboard (musical)" +153,/m/05r5c,"Piano" +154,/m/01s0ps,"Electric piano" +155,/m/013y1f,"Organ" +156,/m/03xq_f,"Electronic organ" +157,/m/03gvt,"Hammond organ" +158,/m/0l14qv,"Synthesizer" +159,/m/01v1d8,"Sampler" +160,/m/03q5t,"Harpsichord" +161,/m/0l14md,"Percussion" +162,/m/02hnl,"Drum kit" +163,/m/0cfdd,"Drum machine" +164,/m/026t6,"Drum" +165,/m/06rvn,"Snare drum" +166,/m/03t3fj,"Rimshot" +167,/m/02k_mr,"Drum roll" +168,/m/0bm02,"Bass drum" +169,/m/011k_j,"Timpani" +170,/m/01p970,"Tabla" +171,/m/01qbl,"Cymbal" +172,/m/03qtq,"Hi-hat" +173,/m/01sm1g,"Wood block" +174,/m/07brj,"Tambourine" +175,/m/05r5wn,"Rattle (instrument)" +176,/m/0xzly,"Maraca" +177,/m/0mbct,"Gong" +178,/m/016622,"Tubular bells" +179,/m/0j45pbj,"Mallet percussion" +180,/m/0dwsp,"Marimba, xylophone" +181,/m/0dwtp,"Glockenspiel" +182,/m/0dwt5,"Vibraphone" +183,/m/0l156b,"Steelpan" +184,/m/05pd6,"Orchestra" +185,/m/01kcd,"Brass instrument" +186,/m/0319l,"French horn" +187,/m/07gql,"Trumpet" +188,/m/07c6l,"Trombone" +189,/m/0l14_3,"Bowed string instrument" +190,/m/02qmj0d,"String section" +191,/m/07y_7,"Violin, fiddle" +192,/m/0d8_n,"Pizzicato" +193,/m/01xqw,"Cello" +194,/m/02fsn,"Double bass" +195,/m/085jw,"Wind instrument, woodwind instrument" +196,/m/0l14j_,"Flute" +197,/m/06ncr,"Saxophone" +198,/m/01wy6,"Clarinet" +199,/m/03m5k,"Harp" +200,/m/0395lw,"Bell" +201,/m/03w41f,"Church bell" +202,/m/027m70_,"Jingle bell" +203,/m/0gy1t2s,"Bicycle bell" +204,/m/07n_g,"Tuning fork" +205,/m/0f8s22,"Chime" +206,/m/026fgl,"Wind chime" +207,/m/0150b9,"Change ringing (campanology)" +208,/m/03qjg,"Harmonica" +209,/m/0mkg,"Accordion" +210,/m/0192l,"Bagpipes" +211,/m/02bxd,"Didgeridoo" +212,/m/0l14l2,"Shofar" +213,/m/07kc_,"Theremin" +214,/m/0l14t7,"Singing bowl" +215,/m/01hgjl,"Scratching (performance technique)" +216,/m/064t9,"Pop music" +217,/m/0glt670,"Hip hop music" +218,/m/02cz_7,"Beatboxing" +219,/m/06by7,"Rock music" +220,/m/03lty,"Heavy metal" +221,/m/05r6t,"Punk rock" +222,/m/0dls3,"Grunge" +223,/m/0dl5d,"Progressive rock" +224,/m/07sbbz2,"Rock and roll" +225,/m/05w3f,"Psychedelic rock" +226,/m/06j6l,"Rhythm and blues" +227,/m/0gywn,"Soul music" +228,/m/06cqb,"Reggae" +229,/m/01lyv,"Country" +230,/m/015y_n,"Swing music" +231,/m/0gg8l,"Bluegrass" +232,/m/02x8m,"Funk" +233,/m/02w4v,"Folk music" +234,/m/06j64v,"Middle Eastern music" +235,/m/03_d0,"Jazz" +236,/m/026z9,"Disco" +237,/m/0ggq0m,"Classical music" +238,/m/05lls,"Opera" +239,/m/02lkt,"Electronic music" +240,/m/03mb9,"House music" +241,/m/07gxw,"Techno" +242,/m/07s72n,"Dubstep" +243,/m/0283d,"Drum and bass" +244,/m/0m0jc,"Electronica" +245,/m/08cyft,"Electronic dance music" +246,/m/0fd3y,"Ambient music" +247,/m/07lnk,"Trance music" +248,/m/0g293,"Music of Latin America" +249,/m/0ln16,"Salsa music" +250,/m/0326g,"Flamenco" +251,/m/0155w,"Blues" +252,/m/05fw6t,"Music for children" +253,/m/02v2lh,"New-age music" +254,/m/0y4f8,"Vocal music" +255,/m/0z9c,"A capella" +256,/m/0164x2,"Music of Africa" +257,/m/0145m,"Afrobeat" +258,/m/02mscn,"Christian music" +259,/m/016cjb,"Gospel music" +260,/m/028sqc,"Music of Asia" +261,/m/015vgc,"Carnatic music" +262,/m/0dq0md,"Music of Bollywood" +263,/m/06rqw,"Ska" +264,/m/02p0sh1,"Traditional music" +265,/m/05rwpb,"Independent music" +266,/m/074ft,"Song" +267,/m/025td0t,"Background music" +268,/m/02cjck,"Theme music" +269,/m/03r5q_,"Jingle (music)" +270,/m/0l14gg,"Soundtrack music" +271,/m/07pkxdp,"Lullaby" +272,/m/01z7dr,"Video game music" +273,/m/0140xf,"Christmas music" +274,/m/0ggx5q,"Dance music" +275,/m/04wptg,"Wedding music" +276,/t/dd00031,"Happy music" +277,/t/dd00032,"Funny music" +278,/t/dd00033,"Sad music" +279,/t/dd00034,"Tender music" +280,/t/dd00035,"Exciting music" +281,/t/dd00036,"Angry music" +282,/t/dd00037,"Scary music" +283,/m/03m9d0z,"Wind" +284,/m/09t49,"Rustling leaves" +285,/t/dd00092,"Wind noise (microphone)" +286,/m/0jb2l,"Thunderstorm" +287,/m/0ngt1,"Thunder" +288,/m/0838f,"Water" +289,/m/06mb1,"Rain" +290,/m/07r10fb,"Raindrop" +291,/t/dd00038,"Rain on surface" +292,/m/0j6m2,"Stream" +293,/m/0j2kx,"Waterfall" +294,/m/05kq4,"Ocean" +295,/m/034srq,"Waves, surf" +296,/m/06wzb,"Steam" +297,/m/07swgks,"Gurgling" +298,/m/02_41,"Fire" +299,/m/07pzfmf,"Crackle" +300,/m/07yv9,"Vehicle" +301,/m/019jd,"Boat, Water vehicle" +302,/m/0hsrw,"Sailboat, sailing ship" +303,/m/056ks2,"Rowboat, canoe, kayak" +304,/m/02rlv9,"Motorboat, speedboat" +305,/m/06q74,"Ship" +306,/m/012f08,"Motor vehicle (road)" +307,/m/0k4j,"Car" +308,/m/0912c9,"Vehicle horn, car horn, honking" +309,/m/07qv_d5,"Toot" +310,/m/02mfyn,"Car alarm" +311,/m/04gxbd,"Power windows, electric windows" +312,/m/07rknqz,"Skidding" +313,/m/0h9mv,"Tire squeal" +314,/t/dd00134,"Car passing by" +315,/m/0ltv,"Race car, auto racing" +316,/m/07r04,"Truck" +317,/m/0gvgw0,"Air brake" +318,/m/05x_td,"Air horn, truck horn" +319,/m/02rhddq,"Reversing beeps" +320,/m/03cl9h,"Ice cream truck, ice cream van" +321,/m/01bjv,"Bus" +322,/m/03j1ly,"Emergency vehicle" +323,/m/04qvtq,"Police car (siren)" +324,/m/012n7d,"Ambulance (siren)" +325,/m/012ndj,"Fire engine, fire truck (siren)" +326,/m/04_sv,"Motorcycle" +327,/m/0btp2,"Traffic noise, roadway noise" +328,/m/06d_3,"Rail transport" +329,/m/07jdr,"Train" +330,/m/04zmvq,"Train whistle" +331,/m/0284vy3,"Train horn" +332,/m/01g50p,"Railroad car, train wagon" +333,/t/dd00048,"Train wheels squealing" +334,/m/0195fx,"Subway, metro, underground" +335,/m/0k5j,"Aircraft" +336,/m/014yck,"Aircraft engine" +337,/m/04229,"Jet engine" +338,/m/02l6bg,"Propeller, airscrew" +339,/m/09ct_,"Helicopter" +340,/m/0cmf2,"Fixed-wing aircraft, airplane" +341,/m/0199g,"Bicycle" +342,/m/06_fw,"Skateboard" +343,/m/02mk9,"Engine" +344,/t/dd00065,"Light engine (high frequency)" +345,/m/08j51y,"Dental drill, dentist's drill" +346,/m/01yg9g,"Lawn mower" +347,/m/01j4z9,"Chainsaw" +348,/t/dd00066,"Medium engine (mid frequency)" +349,/t/dd00067,"Heavy engine (low frequency)" +350,/m/01h82_,"Engine knocking" +351,/t/dd00130,"Engine starting" +352,/m/07pb8fc,"Idling" +353,/m/07q2z82,"Accelerating, revving, vroom" +354,/m/02dgv,"Door" +355,/m/03wwcy,"Doorbell" +356,/m/07r67yg,"Ding-dong" +357,/m/02y_763,"Sliding door" +358,/m/07rjzl8,"Slam" +359,/m/07r4wb8,"Knock" +360,/m/07qcpgn,"Tap" +361,/m/07q6cd_,"Squeak" +362,/m/0642b4,"Cupboard open or close" +363,/m/0fqfqc,"Drawer open or close" +364,/m/04brg2,"Dishes, pots, and pans" +365,/m/023pjk,"Cutlery, silverware" +366,/m/07pn_8q,"Chopping (food)" +367,/m/0dxrf,"Frying (food)" +368,/m/0fx9l,"Microwave oven" +369,/m/02pjr4,"Blender" +370,/m/02jz0l,"Water tap, faucet" +371,/m/0130jx,"Sink (filling or washing)" +372,/m/03dnzn,"Bathtub (filling or washing)" +373,/m/03wvsk,"Hair dryer" +374,/m/01jt3m,"Toilet flush" +375,/m/012xff,"Toothbrush" +376,/m/04fgwm,"Electric toothbrush" +377,/m/0d31p,"Vacuum cleaner" +378,/m/01s0vc,"Zipper (clothing)" +379,/m/03v3yw,"Keys jangling" +380,/m/0242l,"Coin (dropping)" +381,/m/01lsmm,"Scissors" +382,/m/02g901,"Electric shaver, electric razor" +383,/m/05rj2,"Shuffling cards" +384,/m/0316dw,"Typing" +385,/m/0c2wf,"Typewriter" +386,/m/01m2v,"Computer keyboard" +387,/m/081rb,"Writing" +388,/m/07pp_mv,"Alarm" +389,/m/07cx4,"Telephone" +390,/m/07pp8cl,"Telephone bell ringing" +391,/m/01hnzm,"Ringtone" +392,/m/02c8p,"Telephone dialing, DTMF" +393,/m/015jpf,"Dial tone" +394,/m/01z47d,"Busy signal" +395,/m/046dlr,"Alarm clock" +396,/m/03kmc9,"Siren" +397,/m/0dgbq,"Civil defense siren" +398,/m/030rvx,"Buzzer" +399,/m/01y3hg,"Smoke detector, smoke alarm" +400,/m/0c3f7m,"Fire alarm" +401,/m/04fq5q,"Foghorn" +402,/m/0l156k,"Whistle" +403,/m/06hck5,"Steam whistle" +404,/t/dd00077,"Mechanisms" +405,/m/02bm9n,"Ratchet, pawl" +406,/m/01x3z,"Clock" +407,/m/07qjznt,"Tick" +408,/m/07qjznl,"Tick-tock" +409,/m/0l7xg,"Gears" +410,/m/05zc1,"Pulleys" +411,/m/0llzx,"Sewing machine" +412,/m/02x984l,"Mechanical fan" +413,/m/025wky1,"Air conditioning" +414,/m/024dl,"Cash register" +415,/m/01m4t,"Printer" +416,/m/0dv5r,"Camera" +417,/m/07bjf,"Single-lens reflex camera" +418,/m/07k1x,"Tools" +419,/m/03l9g,"Hammer" +420,/m/03p19w,"Jackhammer" +421,/m/01b82r,"Sawing" +422,/m/02p01q,"Filing (rasp)" +423,/m/023vsd,"Sanding" +424,/m/0_ksk,"Power tool" +425,/m/01d380,"Drill" +426,/m/014zdl,"Explosion" +427,/m/032s66,"Gunshot, gunfire" +428,/m/04zjc,"Machine gun" +429,/m/02z32qm,"Fusillade" +430,/m/0_1c,"Artillery fire" +431,/m/073cg4,"Cap gun" +432,/m/0g6b5,"Fireworks" +433,/g/122z_qxw,"Firecracker" +434,/m/07qsvvw,"Burst, pop" +435,/m/07pxg6y,"Eruption" +436,/m/07qqyl4,"Boom" +437,/m/083vt,"Wood" +438,/m/07pczhz,"Chop" +439,/m/07pl1bw,"Splinter" +440,/m/07qs1cx,"Crack" +441,/m/039jq,"Glass" +442,/m/07q7njn,"Chink, clink" +443,/m/07rn7sz,"Shatter" +444,/m/04k94,"Liquid" +445,/m/07rrlb6,"Splash, splatter" +446,/m/07p6mqd,"Slosh" +447,/m/07qlwh6,"Squish" +448,/m/07r5v4s,"Drip" +449,/m/07prgkl,"Pour" +450,/m/07pqc89,"Trickle, dribble" +451,/t/dd00088,"Gush" +452,/m/07p7b8y,"Fill (with liquid)" +453,/m/07qlf79,"Spray" +454,/m/07ptzwd,"Pump (liquid)" +455,/m/07ptfmf,"Stir" +456,/m/0dv3j,"Boiling" +457,/m/0790c,"Sonar" +458,/m/0dl83,"Arrow" +459,/m/07rqsjt,"Whoosh, swoosh, swish" +460,/m/07qnq_y,"Thump, thud" +461,/m/07rrh0c,"Thunk" +462,/m/0b_fwt,"Electronic tuner" +463,/m/02rr_,"Effects unit" +464,/m/07m2kt,"Chorus effect" +465,/m/018w8,"Basketball bounce" +466,/m/07pws3f,"Bang" +467,/m/07ryjzk,"Slap, smack" +468,/m/07rdhzs,"Whack, thwack" +469,/m/07pjjrj,"Smash, crash" +470,/m/07pc8lb,"Breaking" +471,/m/07pqn27,"Bouncing" +472,/m/07rbp7_,"Whip" +473,/m/07pyf11,"Flap" +474,/m/07qb_dv,"Scratch" +475,/m/07qv4k0,"Scrape" +476,/m/07pdjhy,"Rub" +477,/m/07s8j8t,"Roll" +478,/m/07plct2,"Crushing" +479,/t/dd00112,"Crumpling, crinkling" +480,/m/07qcx4z,"Tearing" +481,/m/02fs_r,"Beep, bleep" +482,/m/07qwdck,"Ping" +483,/m/07phxs1,"Ding" +484,/m/07rv4dm,"Clang" +485,/m/07s02z0,"Squeal" +486,/m/07qh7jl,"Creak" +487,/m/07qwyj0,"Rustle" +488,/m/07s34ls,"Whir" +489,/m/07qmpdm,"Clatter" +490,/m/07p9k1k,"Sizzle" +491,/m/07qc9xj,"Clicking" +492,/m/07rwm0c,"Clickety-clack" +493,/m/07phhsh,"Rumble" +494,/m/07qyrcz,"Plop" +495,/m/07qfgpx,"Jingle, tinkle" +496,/m/07rcgpl,"Hum" +497,/m/07p78v5,"Zing" +498,/t/dd00121,"Boing" +499,/m/07s12q4,"Crunch" +500,/m/028v0c,"Silence" +501,/m/01v_m0,"Sine wave" +502,/m/0b9m1,"Harmonic" +503,/m/0hdsk,"Chirp tone" +504,/m/0c1dj,"Sound effect" +505,/m/07pt_g0,"Pulse" +506,/t/dd00125,"Inside, small room" +507,/t/dd00126,"Inside, large room or hall" +508,/t/dd00127,"Inside, public space" +509,/t/dd00128,"Outside, urban or manmade" +510,/t/dd00129,"Outside, rural or natural" +511,/m/01b9nn,"Reverberation" +512,/m/01jnbd,"Echo" +513,/m/096m7z,"Noise" +514,/m/06_y0by,"Environmental noise" +515,/m/07rgkc5,"Static" +516,/m/06xkwv,"Mains hum" +517,/m/0g12c5,"Distortion" +518,/m/08p9q4,"Sidetone" +519,/m/07szfh9,"Cacophony" +520,/m/0chx_,"White noise" +521,/m/0cj0r,"Pink noise" +522,/m/07p_0gm,"Throbbing" +523,/m/01jwx6,"Vibration" +524,/m/07c52,"Television" +525,/m/06bz3,"Radio" +526,/m/07hvw1,"Field recording" diff --git a/a_cls/dataloader.py b/a_cls/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac3659485fadc898a521aa5b8546b5f2a3a4721 --- /dev/null +++ b/a_cls/dataloader.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# @Time : 6/19/21 12:23 AM +# @Author : Yuan Gong +# @Affiliation : Massachusetts Institute of Technology +# @Email : yuangong@mit.edu +# @File : dataloader.py + +# modified from: +# Author: David Harwath +# with some functions borrowed from https://github.com/SeanNaren/deepspeech.pytorch + +import csv +import json +import logging + +import torchaudio +import numpy as np +import torch +import torch.nn.functional +from torch.utils.data import Dataset +import random + +def make_midname_dict(label_csv): + index_lookup = {} + with open(label_csv, 'r') as f: + csv_reader = csv.DictReader(f) + line_count = 0 + for row in csv_reader: + index_lookup[row['mid']] = row['display_name'] + line_count += 1 + return index_lookup + +def make_index_dict(label_csv): + index_lookup = {} + with open(label_csv, 'r') as f: + csv_reader = csv.DictReader(f) + line_count = 0 + for row in csv_reader: + index_lookup[row['mid']] = row['index'] + line_count += 1 + return index_lookup + +def make_name_dict(label_csv): + name_lookup = {} + with open(label_csv, 'r') as f: + csv_reader = csv.DictReader(f) + line_count = 0 + for row in csv_reader: + name_lookup[row['index']] = row['display_name'] + line_count += 1 + return name_lookup + +def lookup_list(index_list, label_csv): + label_list = [] + table = make_name_dict(label_csv) + for item in index_list: + label_list.append(table[item]) + return label_list + +def preemphasis(signal,coeff=0.97): + """perform preemphasis on the input signal. + + :param signal: The signal to filter. + :param coeff: The preemphasis coefficient. 0 is none, default 0.97. + :returns: the filtered signal. + """ + return np.append(signal[0],signal[1:]-coeff*signal[:-1]) + +class AudiosetDataset(Dataset): + def __init__(self, dataset_json_file, audio_conf, label_csv=None): + """ + Dataset that manages audio recordings + :param audio_conf: Dictionary containing the audio loading and preprocessing settings + :param dataset_json_file + """ + self.datapath = dataset_json_file + with open(dataset_json_file, 'r') as fp: + data_json = json.load(fp) + self.data = data_json['data'] + self.index_dict = make_index_dict(label_csv) + self.label_num = len(self.index_dict) + + def __getitem__(self, index): + datum = self.data[index] + label_indices = np.zeros(self.label_num) + try: + fbank, mix_lambda = self._wav2fbank(datum['wav']) + except Exception as e: + logging.warning(f"Error at {datum['wav']} with \"{e}\"") + return self.__getitem__(random.randint(0, self.__len__()-1)) + for label_str in datum['labels'].split(','): + label_indices[int(self.index_dict[label_str])] = 1.0 + + label_indices = torch.FloatTensor(label_indices) + + + return fbank, label_indices + + def __len__(self): + return len(self.data) \ No newline at end of file diff --git a/a_cls/datasets.py b/a_cls/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1289d6af2e5c5fe88dbdab7373b5001e031b14 --- /dev/null +++ b/a_cls/datasets.py @@ -0,0 +1,93 @@ +import os.path + +import torch + +from data.build_datasets import DataInfo +from data.process_audio import get_audio_transform, torchaudio_loader +from torchvision import datasets + +# -*- coding: utf-8 -*- +# @Time : 6/19/21 12:23 AM +# @Author : Yuan Gong +# @Affiliation : Massachusetts Institute of Technology +# @Email : yuangong@mit.edu +# @File : dataloader.py + +# modified from: +# Author: David Harwath +# with some functions borrowed from https://github.com/SeanNaren/deepspeech.pytorch + +import csv +import json +import logging + +import torchaudio +import numpy as np +import torch +import torch.nn.functional +from torch.utils.data import Dataset +import random + + +def make_index_dict(label_csv): + index_lookup = {} + with open(label_csv, 'r') as f: + csv_reader = csv.DictReader(f) + line_count = 0 + for row in csv_reader: + index_lookup[row['mid']] = row['index'] + line_count += 1 + return index_lookup + + +class AudiosetDataset(Dataset): + def __init__(self, args, transform, loader): + self.audio_root = '/apdcephfs_cq3/share_1311970/downstream_datasets/Audio/audioset/eval_segments' + dataset_json_file = '/apdcephfs_cq3/share_1311970/downstream_datasets/Audio/audioset/filter_eval.json' + label_csv = '/apdcephfs_cq3/share_1311970/downstream_datasets/Audio/audioset/class_labels_indices.csv' + with open(dataset_json_file, 'r') as fp: + data_json = json.load(fp) + self.data = data_json['data'] + self.index_dict = make_index_dict(label_csv) + self.label_num = len(self.index_dict) + + self.args = args + self.transform = transform + self.loader = loader + + def __getitem__(self, index): + datum = self.data[index] + label_indices = np.zeros(self.label_num) + for label_str in datum['labels'].split(','): + label_indices[int(self.index_dict[label_str])] = 1.0 + label_indices = torch.FloatTensor(label_indices) + + audio = self.loader(os.path.join(self.audio_root, datum['wav'])) + audio_data = self.transform(audio) + return audio_data, label_indices + + def __len__(self): + return len(self.data) + + + +def is_valid_file(path): + return True + +def get_audio_dataset(args): + data_path = args.audio_data_path + transform = get_audio_transform(args) + + if args.val_a_cls_data.lower() == 'audioset': + dataset = AudiosetDataset(args, transform=transform, loader=torchaudio_loader) + else: + dataset = datasets.ImageFolder(data_path, transform=transform, loader=torchaudio_loader, is_valid_file=is_valid_file) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=None, + ) + + return DataInfo(dataloader=dataloader, sampler=None) diff --git a/a_cls/filter_eval_audio.py b/a_cls/filter_eval_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..30d146d7131eee6fd49c002dbfac6a8c9423a998 --- /dev/null +++ b/a_cls/filter_eval_audio.py @@ -0,0 +1,21 @@ +import json +import os.path +from tqdm import tqdm + +with open(r"G:\audioset\audioset\zip_audios\16k\eval.json", 'r') as f: + data = json.load(f)['data'] + +new_data = [] +total = 0 +success = 0 +for i in tqdm(data): + total += 1 + video_id = os.path.basename(i['wav']) + new_video_id = 'Y' + video_id + i['wav'] = new_video_id + if os.path.exists(f"G:/audioset/audioset/zip_audios/eval_segments/{i['wav']}") and not video_id.startswith('mW3S0u8bj58'): + new_data.append(i) + success += 1 +print(total, success, total-success) +with open(r"G:\audioset\audioset\zip_audios\16k\filter_eval.json", 'w') as f: + data = json.dump({'data': new_data}, f, indent=2) \ No newline at end of file diff --git a/a_cls/precision.py b/a_cls/precision.py new file mode 100644 index 0000000000000000000000000000000000000000..a63b92256518d13afd57261df1568e26b1622201 --- /dev/null +++ b/a_cls/precision.py @@ -0,0 +1,12 @@ +import torch +from contextlib import suppress + + +def get_autocast(precision): + if precision == 'amp': + return torch.cuda.amp.autocast + elif precision == 'amp_bfloat16' or precision == 'amp_bf16': + # amp_bfloat16 is more stable than amp float16 for clip training + return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return suppress diff --git a/a_cls/stats.py b/a_cls/stats.py new file mode 100644 index 0000000000000000000000000000000000000000..35a2418f6427c70a8ff9d214dd2f107d18e945d6 --- /dev/null +++ b/a_cls/stats.py @@ -0,0 +1,57 @@ +import numpy as np +from scipy import stats +from sklearn import metrics +import torch + +def d_prime(auc): + standard_normal = stats.norm() + d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) + return d_prime + +def calculate_stats(output, target): + """Calculate statistics including mAP, AUC, etc. + + Args: + output: 2d array, (samples_num, classes_num) + target: 2d array, (samples_num, classes_num) + + Returns: + stats: list of statistic of each class. + """ + + classes_num = target.shape[-1] + stats = [] + + # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet + acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1)) + + # Class-wise statistics + for k in range(classes_num): + + # Average precision + avg_precision = metrics.average_precision_score( + target[:, k], output[:, k], average=None) + + # AUC + auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None) + + # Precisions, recalls + (precisions, recalls, thresholds) = metrics.precision_recall_curve( + target[:, k], output[:, k]) + + # FPR, TPR + (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k]) + + save_every_steps = 1000 # Sample statistics to reduce size + dict = {'precisions': precisions[0::save_every_steps], + 'recalls': recalls[0::save_every_steps], + 'AP': avg_precision, + 'fpr': fpr[0::save_every_steps], + 'fnr': 1. - tpr[0::save_every_steps], + 'auc': auc, + # note acc is not class-wise, this is just to keep consistent with other metrics + 'acc': acc + } + stats.append(dict) + + return stats diff --git a/a_cls/util.py b/a_cls/util.py new file mode 100644 index 0000000000000000000000000000000000000000..7c48efe4f3609256d3d2230f7bf24b9c18ae5bbb --- /dev/null +++ b/a_cls/util.py @@ -0,0 +1,306 @@ +import math +import pickle +import numpy as np +import torch +import torch.nn as nn +import random +from collections import namedtuple + +def calc_recalls(S): + """ + Computes recall at 1, 5, and 10 given a similarity matrix S. + By convention, rows of S are assumed to correspond to images and columns are captions. + """ + assert(S.dim() == 2) + assert(S.size(0) == S.size(1)) + if isinstance(S, torch.autograd.Variable): + S = S.data + n = S.size(0) + A2I_scores, A2I_ind = S.topk(10, 0) + I2A_scores, I2A_ind = S.topk(10, 1) + A_r1 = AverageMeter() + A_r5 = AverageMeter() + A_r10 = AverageMeter() + I_r1 = AverageMeter() + I_r5 = AverageMeter() + I_r10 = AverageMeter() + for i in range(n): + A_foundind = -1 + I_foundind = -1 + for ind in range(10): + if A2I_ind[ind, i] == i: + I_foundind = ind + if I2A_ind[i, ind] == i: + A_foundind = ind + # do r1s + if A_foundind == 0: + A_r1.update(1) + else: + A_r1.update(0) + if I_foundind == 0: + I_r1.update(1) + else: + I_r1.update(0) + # do r5s + if A_foundind >= 0 and A_foundind < 5: + A_r5.update(1) + else: + A_r5.update(0) + if I_foundind >= 0 and I_foundind < 5: + I_r5.update(1) + else: + I_r5.update(0) + # do r10s + if A_foundind >= 0 and A_foundind < 10: + A_r10.update(1) + else: + A_r10.update(0) + if I_foundind >= 0 and I_foundind < 10: + I_r10.update(1) + else: + I_r10.update(0) + + recalls = {'A_r1':A_r1.avg, 'A_r5':A_r5.avg, 'A_r10':A_r10.avg, + 'I_r1':I_r1.avg, 'I_r5':I_r5.avg, 'I_r10':I_r10.avg} + #'A_meanR':A_meanR.avg, 'I_meanR':I_meanR.avg} + + return recalls + +def computeMatchmap(I, A): + assert(I.dim() == 3) + assert(A.dim() == 2) + D = I.size(0) + H = I.size(1) + W = I.size(2) + T = A.size(1) + Ir = I.view(D, -1).t() + matchmap = torch.mm(Ir, A) + matchmap = matchmap.view(H, W, T) + return matchmap + +def matchmapSim(M, simtype): + assert(M.dim() == 3) + if simtype == 'SISA': + return M.mean() + elif simtype == 'MISA': + M_maxH, _ = M.max(0) + M_maxHW, _ = M_maxH.max(0) + return M_maxHW.mean() + elif simtype == 'SIMA': + M_maxT, _ = M.max(2) + return M_maxT.mean() + else: + raise ValueError + +def sampled_margin_rank_loss(image_outputs, audio_outputs, nframes, margin=1., simtype='MISA'): + """ + Computes the triplet margin ranking loss for each anchor image/caption pair + The impostor image/caption is randomly sampled from the minibatch + """ + assert(image_outputs.dim() == 4) + assert(audio_outputs.dim() == 3) + n = image_outputs.size(0) + loss = torch.zeros(1, device=image_outputs.device, requires_grad=True) + for i in range(n): + I_imp_ind = i + A_imp_ind = i + while I_imp_ind == i: + I_imp_ind = np.random.randint(0, n) + while A_imp_ind == i: + A_imp_ind = np.random.randint(0, n) + nF = nframes[i] + nFimp = nframes[A_imp_ind] + anchorsim = matchmapSim(computeMatchmap(image_outputs[i], audio_outputs[i][:, 0:nF]), simtype) + Iimpsim = matchmapSim(computeMatchmap(image_outputs[I_imp_ind], audio_outputs[i][:, 0:nF]), simtype) + Aimpsim = matchmapSim(computeMatchmap(image_outputs[i], audio_outputs[A_imp_ind][:, 0:nFimp]), simtype) + A2I_simdif = margin + Iimpsim - anchorsim + if (A2I_simdif.data > 0).all(): + loss = loss + A2I_simdif + I2A_simdif = margin + Aimpsim - anchorsim + if (I2A_simdif.data > 0).all(): + loss = loss + I2A_simdif + loss = loss / n + return loss + +def compute_matchmap_similarity_matrix(image_outputs, audio_outputs, nframes, simtype='MISA'): + """ + Assumes image_outputs is a (batchsize, embedding_dim, rows, height) tensor + Assumes audio_outputs is a (batchsize, embedding_dim, 1, time) tensor + Returns similarity matrix S where images are rows and audios are along the columns + """ + assert(image_outputs.dim() == 4) + assert(audio_outputs.dim() == 3) + n = image_outputs.size(0) + S = torch.zeros(n, n, device=image_outputs.device) + for image_idx in range(n): + for audio_idx in range(n): + nF = max(1, nframes[audio_idx]) + S[image_idx, audio_idx] = matchmapSim(computeMatchmap(image_outputs[image_idx], audio_outputs[audio_idx][:, 0:nF]), simtype) + return S + +def compute_pooldot_similarity_matrix(image_outputs, audio_outputs, nframes): + """ + Assumes image_outputs is a (batchsize, embedding_dim, rows, height) tensor + Assumes audio_outputs is a (batchsize, embedding_dim, 1, time) tensor + Returns similarity matrix S where images are rows and audios are along the columns + S[i][j] is computed as the dot product between the meanpooled embeddings of + the ith image output and jth audio output + """ + assert(image_outputs.dim() == 4) + assert(audio_outputs.dim() == 4) + n = image_outputs.size(0) + imagePoolfunc = nn.AdaptiveAvgPool2d((1, 1)) + pooled_image_outputs = imagePoolfunc(image_outputs).squeeze(3).squeeze(2) + audioPoolfunc = nn.AdaptiveAvgPool2d((1, 1)) + pooled_audio_outputs_list = [] + for idx in range(n): + nF = max(1, nframes[idx]) + pooled_audio_outputs_list.append(audioPoolfunc(audio_outputs[idx][:, :, 0:nF]).unsqueeze(0)) + pooled_audio_outputs = torch.cat(pooled_audio_outputs_list).squeeze(3).squeeze(2) + S = torch.mm(pooled_image_outputs, pooled_audio_outputs.t()) + return S + +def one_imposter_index(i, N): + imp_ind = random.randint(0, N - 2) + if imp_ind == i: + imp_ind = N - 1 + return imp_ind + +def basic_get_imposter_indices(N): + imposter_idc = [] + for i in range(N): + # Select an imposter index for example i: + imp_ind = one_imposter_index(i, N) + imposter_idc.append(imp_ind) + return imposter_idc + +def semihardneg_triplet_loss_from_S(S, margin): + """ + Input: Similarity matrix S as an autograd.Variable + Output: The one-way triplet loss from rows of S to columns of S. Impostors are taken + to be the most similar point to the anchor that is still less similar to the anchor + than the positive example. + You would need to run this function twice, once with S and once with S.t(), + in order to compute the triplet loss in both directions. + """ + assert(S.dim() == 2) + assert(S.size(0) == S.size(1)) + N = S.size(0) + loss = torch.autograd.Variable(torch.zeros(1).type(S.data.type()), requires_grad=True) + # Imposter - ground truth + Sdiff = S - torch.diag(S).view(-1, 1) + eps = 1e-12 + # All examples less similar than ground truth + mask = (Sdiff < -eps).type(torch.LongTensor) + maskf = mask.type_as(S) + # Mask out all examples >= gt with minimum similarity + Sp = maskf * Sdiff + (1 - maskf) * torch.min(Sdiff).detach() + # Find the index maximum similar of the remaining + _, idc = Sp.max(dim=1) + idc = idc.data.cpu() + # Vector mask: 1 iff there exists an example < gt + has_neg = (mask.sum(dim=1) > 0).data.type(torch.LongTensor) + # Random imposter indices + random_imp_ind = torch.LongTensor(basic_get_imposter_indices(N)) + # Use hardneg if there exists an example < gt, otherwise use random imposter + imp_idc = has_neg * idc + (1 - has_neg) * random_imp_ind + # This could probably be vectorized too, but I haven't. + for i, imp in enumerate(imp_idc): + local_loss = Sdiff[i, imp] + margin + if (local_loss.data > 0).all(): + loss = loss + local_loss + loss = loss / N + return loss + +def sampled_triplet_loss_from_S(S, margin): + """ + Input: Similarity matrix S as an autograd.Variable + Output: The one-way triplet loss from rows of S to columns of S. Imposters are + randomly sampled from the columns of S. + You would need to run this function twice, once with S and once with S.t(), + in order to compute the triplet loss in both directions. + """ + assert(S.dim() == 2) + assert(S.size(0) == S.size(1)) + N = S.size(0) + loss = torch.autograd.Variable(torch.zeros(1).type(S.data.type()), requires_grad=True) + # Imposter - ground truth + Sdiff = S - torch.diag(S).view(-1, 1) + imp_ind = torch.LongTensor(basic_get_imposter_indices(N)) + # This could probably be vectorized too, but I haven't. + for i, imp in enumerate(imp_ind): + local_loss = Sdiff[i, imp] + margin + if (local_loss.data > 0).all(): + loss = loss + local_loss + loss = loss / N + return loss + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +def adjust_learning_rate(base_lr, lr_decay, optimizer, epoch): + """Sets the learning rate to the initial LR decayed by 10 every lr_decay epochs""" + lr = base_lr * (0.1 ** (epoch // lr_decay)) + print('now learning rate changed to {:f}'.format(lr)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def adjust_learning_rate2(base_lr, lr_decay, optimizer, epoch): + """Sets the learning rate to the initial LR decayed by 10 every lr_decay epochs""" + for param_group in optimizer.param_groups: + cur_lr = param_group['lr'] + print('current learing rate is {:f}'.format(lr)) + lr = cur_lr * 0.1 + print('now learning rate changed to {:f}'.format(lr)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def load_progress(prog_pkl, quiet=False): + """ + load progress pkl file + Args: + prog_pkl(str): path to progress pkl file + Return: + progress(list): + epoch(int): + global_step(int): + best_epoch(int): + best_avg_r10(float): + """ + def _print(msg): + if not quiet: + print(msg) + + with open(prog_pkl, "rb") as f: + prog = pickle.load(f) + epoch, global_step, best_epoch, best_avg_r10, _ = prog[-1] + + _print("\nPrevious Progress:") + msg = "[%5s %7s %5s %7s %6s]" % ("epoch", "step", "best_epoch", "best_avg_r10", "time") + _print(msg) + return prog, epoch, global_step, best_epoch, best_avg_r10 + +def count_parameters(model): + return sum([p.numel() for p in model.parameters() if p.requires_grad]) + +PrenetConfig = namedtuple( + 'PrenetConfig', ['input_size', 'hidden_size', 'num_layers', 'dropout']) + +RNNConfig = namedtuple( + 'RNNConfig', + ['input_size', 'hidden_size', 'num_layers', 'dropout', 'residual']) \ No newline at end of file diff --git a/a_cls/zero_shot.py b/a_cls/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..57d709c77d4c9f0c565902d5bf4229df06a67265 --- /dev/null +++ b/a_cls/zero_shot.py @@ -0,0 +1,234 @@ +import logging +import os + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from tqdm import tqdm + +from open_clip import get_input_dtype, get_tokenizer +from open_clip.factory import HF_HUB_PREFIX +from .precision import get_autocast +from .stats import calculate_stats, d_prime +from .zero_shot_classifier import build_zero_shot_classifier +from .zero_shot_metadata import CLASSNAMES, OPENAI_IMAGENET_TEMPLATES + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] + + +def run(model, classifier, dataloader, args): + autocast = get_autocast(args.precision) + input_dtype = get_input_dtype(args.precision) + + with torch.no_grad(): + top1, top5, n = 0., 0., 0. + for images, target in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(device=args.device, dtype=input_dtype) + images = images.unsqueeze(2) + target = target.to(args.device) + + with autocast(): + # predict + output = model(image=images) + image_features = output['image_features'] if isinstance(output, dict) else output[0] + logits = 100. * image_features @ classifier + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = (top1 / n) + top5 = (top5 / n) + return top1, top5 + + +def validate(audio_model, classifier, val_loader, args, epoch): + epoch = epoch - 1 ######################## + # switch to evaluate mode + audio_model.eval() + autocast = get_autocast(args.precision) + input_dtype = get_input_dtype(args.precision) + A_predictions = [] + A_targets = [] + A_loss = [] + with torch.no_grad(): + for i, (audio_input, labels) in enumerate(tqdm(val_loader)): + audio_input = audio_input.to(device=args.device, dtype=input_dtype) + + # compute output + with autocast(): + # predict + output = audio_model(image=audio_input) + image_features = output['image_features'] if isinstance(output, dict) else output[0] + logits = 100. * image_features @ classifier + audio_output = logits + + # audio_output = torch.sigmoid(audio_output) + predictions = audio_output.to('cpu').detach() + + A_predictions.append(predictions) + A_targets.append(labels) + + # compute the loss + labels = labels.to(args.device) + loss = nn.CrossEntropyLoss()(audio_output, torch.argmax(labels.long(), dim=1)) + A_loss.append(loss.to('cpu').detach()) + + audio_output = torch.cat(A_predictions) + target = torch.cat(A_targets) + loss = np.mean(A_loss) + stats = calculate_stats(audio_output, target) + + # save the prediction here + args.a_cls_output_dir = os.path.join(args.log_base_path, f'a_cls/{args.val_a_cls_data.lower()}') + os.makedirs(args.a_cls_output_dir, exist_ok=True) + if os.path.exists(args.a_cls_output_dir + '/predictions') == False: + os.mkdir(args.a_cls_output_dir + '/predictions') + np.savetxt(args.a_cls_output_dir + '/predictions/target.csv', target, delimiter=',') + np.savetxt(args.a_cls_output_dir + '/predictions/predictions_' + str(epoch) + '.csv', audio_output, + delimiter=',') + + valid_loss = loss + main_metrics = 'mAP' + metrics = {} + + if args.do_train: + # ensemble results + cum_stats = validate_ensemble(args, epoch) + cum_mAP = np.mean([stat['AP'] for stat in cum_stats]) + cum_mAUC = np.mean([stat['auc'] for stat in cum_stats]) + cum_acc = cum_stats[0]['acc'] + + mAP = np.mean([stat['AP'] for stat in stats]) + mAUC = np.mean([stat['auc'] for stat in stats]) + acc = stats[0]['acc'] + + middle_ps = [stat['precisions'][int(len(stat['precisions']) / 2)] for stat in stats] + middle_rs = [stat['recalls'][int(len(stat['recalls']) / 2)] for stat in stats] + average_precision = np.mean(middle_ps) + average_recall = np.mean(middle_rs) + + if main_metrics == 'mAP': + logging.info("mAP: {:.6f}".format(mAP)) + else: + logging.info("acc: {:.6f}".format(acc)) + logging.info("AUC: {:.6f}".format(mAUC)) + logging.info("Avg Precision: {:.6f}".format(average_precision)) + logging.info("Avg Recall: {:.6f}".format(average_recall)) + logging.info("d_prime: {:.6f}".format(d_prime(mAUC))) + logging.info("valid_loss: {:.6f}".format(valid_loss)) + + if args.do_train: + logging.info("cum_mAP: {:.6f}".format(cum_mAP)) + logging.info("cum_mAUC: {:.6f}".format(cum_mAUC)) + + if main_metrics == 'mAP': + metrics['mAP'] = float(mAP) + else: + metrics['acc'] = float(acc) + + metrics['mAUC'] = float(mAUC) + metrics['average_precision'] = float(average_precision) + metrics['average_recall'] = float(average_recall) + metrics['d_prime_mAUC'] = float(d_prime(mAUC)) + metrics['valid_loss'] = float(valid_loss) + + if args.do_train: + metrics['cum_mAP'] = float(cum_mAP) + metrics['cum_mAUC'] = float(cum_mAUC) + + return metrics + + +def validate_ensemble(args, epoch): + exp_dir = args.a_cls_output_dir + target = np.loadtxt(exp_dir + '/predictions/target.csv', delimiter=',') + if epoch == 0: + cum_predictions = np.loadtxt(exp_dir + '/predictions/predictions_0.csv', delimiter=',') + else: + cum_predictions = np.loadtxt(exp_dir + '/predictions/cum_predictions.csv', delimiter=',') * (epoch - 1) + predictions = np.loadtxt(exp_dir + '/predictions/predictions_' + str(epoch) + '.csv', delimiter=',') + cum_predictions = cum_predictions + predictions + # remove the prediction file to save storage space + os.remove(exp_dir + '/predictions/predictions_' + str(epoch - 1) + '.csv') + + cum_predictions = cum_predictions / (epoch + 1) + np.savetxt(exp_dir + '/predictions/cum_predictions.csv', cum_predictions, delimiter=',') + + stats = calculate_stats(cum_predictions, target) + return stats + + + + + + + + + +def zero_shot_eval(model, data, epoch, args): + temp_val_a_cls_data = args.val_a_cls_data + args.val_a_cls_data = list(data.keys()) + assert len(args.val_a_cls_data) == 1 + args.val_a_cls_data = args.val_a_cls_data[0] + + if args.val_a_cls_data not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + if args.distributed and not args.horovod: + model = model.module + + logging.info(f'Starting zero-shot {args.val_a_cls_data.upper()}.') + + logging.info('Building zero-shot classifier') + autocast = get_autocast(args.precision) + with autocast(): + tokenizer = get_tokenizer(HF_HUB_PREFIX+args.model, cache_dir=args.cache_dir) + # tokenizer = get_tokenizer("ViT-L-14") + classifier = build_zero_shot_classifier( + model, + tokenizer=tokenizer, + classnames=CLASSNAMES[args.val_a_cls_data], + templates=OPENAI_IMAGENET_TEMPLATES, + num_classes_per_batch=10, + device=args.device, + use_tqdm=True, + ) + + logging.info('Using classifier') + results = {} + if args.val_a_cls_data.lower() == 'audioset': + if args.val_a_cls_data in data: + stats = validate(model, classifier, data[args.val_a_cls_data].dataloader, args, epoch) + results.update(stats) + else: + if args.val_a_cls_data in data: + top1, top5 = run(model, classifier, data[args.val_a_cls_data].dataloader, args) + results[f'{args.val_a_cls_data}-zeroshot-val-top1'] = top1 + results[f'{args.val_a_cls_data}-zeroshot-val-top5'] = top5 + + logging.info(f'Finished zero-shot {args.val_a_cls_data.upper()}.') + + args.val_a_cls_data = temp_val_a_cls_data + return results + + + + + + + + + + + diff --git a/a_cls/zero_shot_classifier.py b/a_cls/zero_shot_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a5267cea4119994e30bb4830a6744cf25bdbaf --- /dev/null +++ b/a_cls/zero_shot_classifier.py @@ -0,0 +1,111 @@ +from functools import partial +from itertools import islice +from typing import Callable, List, Optional, Sequence, Union + +import torch +import torch.nn.functional as F + + +def batched(iterable, n): + """Batch data into lists of length *n*. The last batch may be shorter. + NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl + """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch + + +def build_zero_shot_classifier( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + num_classes_per_batch: Optional[int] = 10, + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names in batches + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + num_classes_per_batch: The number of classes to batch together in each forward, all if None + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + use_format = isinstance(templates[0], str) + num_templates = len(templates) + num_classes = len(classnames) + if use_tqdm: + import tqdm + num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) + iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) + else: + iter_wrap = iter + + def _process_batch(batch_classnames): + num_batch_classes = len(batch_classnames) + texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] + input_ids, attention_mask = tokenizer(texts) + input_ids, attention_mask = input_ids.to(device), attention_mask.to(device) + class_embeddings = F.normalize(model.encode_text(input_ids, attention_mask), dim=-1) + class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) + class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) + class_embeddings = class_embeddings.T + return class_embeddings + + with torch.no_grad(): + if num_classes_per_batch: + batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] + zeroshot_weights = torch.cat(batched_embeds, dim=1) + else: + zeroshot_weights = _process_batch(classnames) + return zeroshot_weights + + +def build_zero_shot_classifier_legacy( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names 1 by 1 + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + if use_tqdm: + import tqdm + iter_wrap = tqdm.tqdm + else: + iter_wrap = iter + + use_format = isinstance(templates[0], str) + + with torch.no_grad(): + zeroshot_weights = [] + for classname in iter_wrap(classnames): + texts = [template.format(classname) if use_format else template(classname) for template in templates] + texts = tokenizer(texts).to(device) # tokenize + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) + + return zeroshot_weights + diff --git a/a_cls/zero_shot_metadata.py b/a_cls/zero_shot_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..ee456ec1c345cc08aadda0343c844c0265c9e84e --- /dev/null +++ b/a_cls/zero_shot_metadata.py @@ -0,0 +1,184 @@ +import os + +import pandas as pd + +OPENAI_IMAGENET_TEMPLATES = ( + # lambda c: f'This is a sound of {c}.', + lambda c: f'a sound of {c}.', +) +# OPENAI_IMAGENET_TEMPLATES = ( +# lambda c: f'a bad sound of a {c}.', +# lambda c: f'a sound of many {c}.', +# lambda c: f'a sculpture of a {c}.', +# lambda c: f'a sound of the hard to see {c}.', +# lambda c: f'a low resolution sound of the {c}.', +# lambda c: f'a rendering of a {c}.', +# lambda c: f'graffiti of a {c}.', +# lambda c: f'a bad sound of the {c}.', +# lambda c: f'a cropped sound of the {c}.', +# lambda c: f'a tattoo of a {c}.', +# lambda c: f'the embroidered {c}.', +# lambda c: f'a sound of a hard to see {c}.', +# lambda c: f'a bright sound of a {c}.', +# lambda c: f'a sound of a clean {c}.', +# lambda c: f'a sound of a dirty {c}.', +# lambda c: f'a dark sound of the {c}.', +# lambda c: f'a drawing of a {c}.', +# lambda c: f'a sound of my {c}.', +# lambda c: f'the plastic {c}.', +# lambda c: f'a sound of the cool {c}.', +# lambda c: f'a close-up sound of a {c}.', +# lambda c: f'a black and white sound of the {c}.', +# lambda c: f'a painting of the {c}.', +# lambda c: f'a painting of a {c}.', +# lambda c: f'a pixelated sound of the {c}.', +# lambda c: f'a sculpture of the {c}.', +# lambda c: f'a bright sound of the {c}.', +# lambda c: f'a cropped sound of a {c}.', +# lambda c: f'a plastic {c}.', +# lambda c: f'a sound of the dirty {c}.', +# lambda c: f'a jpeg corrupted sound of a {c}.', +# lambda c: f'a blurry sound of the {c}.', +# lambda c: f'a sound of the {c}.', +# lambda c: f'a good sound of the {c}.', +# lambda c: f'a rendering of the {c}.', +# lambda c: f'a {c} in a video game.', +# lambda c: f'a sound of one {c}.', +# lambda c: f'a doodle of a {c}.', +# lambda c: f'a close-up sound of the {c}.', +# lambda c: f'a sound of a {c}.', +# lambda c: f'the origami {c}.', +# lambda c: f'the {c} in a video game.', +# lambda c: f'a sketch of a {c}.', +# lambda c: f'a doodle of the {c}.', +# lambda c: f'a origami {c}.', +# lambda c: f'a low resolution sound of a {c}.', +# lambda c: f'the toy {c}.', +# lambda c: f'a rendition of the {c}.', +# lambda c: f'a sound of the clean {c}.', +# lambda c: f'a sound of a large {c}.', +# lambda c: f'a rendition of a {c}.', +# lambda c: f'a sound of a nice {c}.', +# lambda c: f'a sound of a weird {c}.', +# lambda c: f'a blurry sound of a {c}.', +# lambda c: f'a cartoon {c}.', +# lambda c: f'art of a {c}.', +# lambda c: f'a sketch of the {c}.', +# lambda c: f'a embroidered {c}.', +# lambda c: f'a pixelated sound of a {c}.', +# lambda c: f'itap of the {c}.', +# lambda c: f'a jpeg corrupted sound of the {c}.', +# lambda c: f'a good sound of a {c}.', +# lambda c: f'a plushie {c}.', +# lambda c: f'a sound of the nice {c}.', +# lambda c: f'a sound of the small {c}.', +# lambda c: f'a sound of the weird {c}.', +# lambda c: f'the cartoon {c}.', +# lambda c: f'art of the {c}.', +# lambda c: f'a drawing of the {c}.', +# lambda c: f'a sound of the large {c}.', +# lambda c: f'a black and white sound of a {c}.', +# lambda c: f'the plushie {c}.', +# lambda c: f'a dark sound of a {c}.', +# lambda c: f'itap of a {c}.', +# lambda c: f'graffiti of the {c}.', +# lambda c: f'a toy {c}.', +# lambda c: f'itap of my {c}.', +# lambda c: f'a sound of a cool {c}.', +# lambda c: f'a sound of a small {c}.', +# lambda c: f'a tattoo of the {c}.', +# ) + +# a much smaller subset of above prompts +# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +SIMPLE_IMAGENET_TEMPLATES = ( + lambda c: f'itap of a {c}.', + lambda c: f'a bad sound of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a sound of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a sound of the small {c}.', +) + + +PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "class_labels_indices.csv") + + +CLASSNAMES = { + 'Audioset': tuple(pd.read_csv(PATH).values[:, 2]), + 'ESC50': ( + 'airplane', 'breathing', 'brushing teeth', 'can opening', 'car horn', 'cat', 'chainsaw', 'chirping birds', + 'church bells', 'clapping', 'clock alarm', 'clock tick', 'coughing', 'cow', 'crackling fire', 'crickets', + 'crow', 'crying baby', 'dog', 'door wood creaks', 'door wood knock', 'drinking sipping', 'engine', 'fireworks', + 'footsteps', 'frog', 'glass breaking', 'hand saw', 'helicopter', 'hen', 'insects', 'keyboard typing', + 'laughing', 'mouse click', 'pig', 'pouring water', 'rain', 'rooster', 'sea waves', 'sheep', 'siren', + 'sneezing', 'snoring', 'thunderstorm', 'toilet flush', 'train', 'vacuum cleaner', 'washing machine', + 'water drops', 'wind' + ), + 'VGGSound': ( + 'air conditioning noise', 'air horn', 'airplane', 'airplane flyby', 'alarm clock ringing', + 'alligators, crocodiles hissing', 'ambulance siren', 'arc welding', 'baby babbling', 'baby crying', + 'baby laughter', 'baltimore oriole calling', 'barn swallow calling', 'basketball bounce', + 'bathroom ventilation fan running', 'beat boxing', 'bee, wasp, etc. buzzing', 'bird chirping, tweeting', + 'bird squawking', 'bird wings flapping', 'black capped chickadee calling', 'blowtorch igniting', + 'bouncing on trampoline', 'bowling impact', 'bull bellowing', 'canary calling', 'cap gun shooting', + 'car engine idling', 'car engine knocking', 'car engine starting', 'car passing by', 'cat caterwauling', + 'cat growling', 'cat hissing', 'cat meowing', 'cat purring', 'cattle mooing', 'cattle, bovinae cowbell', + 'cell phone buzzing', 'chainsawing trees', 'cheetah chirrup', 'chicken clucking', 'chicken crowing', + 'child singing', 'child speech, kid speaking', 'children shouting', 'chimpanzee pant-hooting', + 'chinchilla barking', 'chipmunk chirping', 'chopping food', 'chopping wood', 'church bell ringing', + 'civil defense siren', 'cow lowing', 'coyote howling', 'cricket chirping', 'crow cawing', 'cuckoo bird calling', + 'cupboard opening or closing', 'cutting hair with electric trimmers', 'dinosaurs bellowing', 'disc scratching', + 'dog barking', 'dog baying', 'dog bow-wow', 'dog growling', 'dog howling', 'dog whimpering', + 'donkey, ass braying', 'door slamming', 'driving buses', 'driving motorcycle', 'driving snowmobile', + 'duck quacking', 'eagle screaming', 'eating with cutlery', 'electric grinder grinding', + 'electric shaver, electric razor shaving', 'elephant trumpeting', 'eletric blender running', 'elk bugling', + 'engine accelerating, revving, vroom', 'female singing', 'female speech, woman speaking', 'ferret dooking', + 'fire crackling', 'fire truck siren', 'fireworks banging', 'firing cannon', 'firing muskets', + 'fly, housefly buzzing', 'foghorn', 'footsteps on snow', 'forging swords', 'fox barking', 'francolin calling', + 'frog croaking', 'gibbon howling', 'goat bleating', 'golf driving', 'goose honking', 'hail', + 'hair dryer drying', 'hammering nails', 'heart sounds, heartbeat', 'hedge trimmer running', 'helicopter', + 'horse clip-clop', 'horse neighing', 'ice cracking', 'ice cream truck, ice cream van', 'lathe spinning', + 'lawn mowing', 'lighting firecrackers', 'lions growling', 'lions roaring', 'lip smacking', + 'machine gun shooting', 'magpie calling', 'male singing', 'male speech, man speaking', 'metronome', + 'missile launch', 'mosquito buzzing', 'motorboat, speedboat acceleration', 'mouse clicking', 'mouse pattering', + 'mouse squeaking', 'mynah bird singing', 'ocean burbling', 'opening or closing car doors', + 'opening or closing car electric windows', 'opening or closing drawers', 'orchestra', 'otter growling', + 'owl hooting', 'parrot talking', 'penguins braying', 'people babbling', 'people battle cry', + 'people belly laughing', 'people booing', 'people burping', 'people cheering', 'people clapping', + 'people coughing', 'people crowd', 'people eating', 'people eating apple', 'people eating crisps', + 'people eating noodle', 'people farting', 'people finger snapping', 'people gargling', 'people giggling', + 'people hiccup', 'people humming', 'people marching', 'people nose blowing', 'people running', + 'people screaming', 'people shuffling', 'people slapping', 'people slurping', 'people sneezing', + 'people sniggering', 'people sobbing', 'people whispering', 'people whistling', 'pheasant crowing', + 'pig oinking', 'pigeon, dove cooing', 'planing timber', 'plastic bottle crushing', 'playing accordion', + 'playing acoustic guitar', 'playing badminton', 'playing bagpipes', 'playing banjo', 'playing bass drum', + 'playing bass guitar', 'playing bassoon', 'playing bongo', 'playing bugle', 'playing castanets', + 'playing cello', 'playing clarinet', 'playing congas', 'playing cornet', 'playing cymbal', 'playing darts', + 'playing didgeridoo', 'playing djembe', 'playing double bass', 'playing drum kit', 'playing electric guitar', + 'playing electronic organ', 'playing erhu', 'playing flute', 'playing french horn', 'playing glockenspiel', + 'playing gong', 'playing guiro', 'playing hammond organ', 'playing harmonica', 'playing harp', + 'playing harpsichord', 'playing hockey', 'playing lacrosse', 'playing mandolin', 'playing marimba, xylophone', + 'playing oboe', 'playing piano', 'playing saxophone', 'playing shofar', 'playing sitar', 'playing snare drum', + 'playing squash', 'playing steel guitar, slide guitar', 'playing steelpan', 'playing synthesizer', + 'playing tabla', 'playing table tennis', 'playing tambourine', 'playing tennis', 'playing theremin', + 'playing timbales', 'playing timpani', 'playing trombone', 'playing trumpet', 'playing tuning fork', + 'playing tympani', 'playing ukulele', 'playing vibraphone', 'playing violin, fiddle', 'playing volleyball', + 'playing washboard', 'playing zither', 'police car (siren)', 'police radio chatter', 'popping popcorn', + 'printer printing', 'pumping water', 'race car, auto racing', 'railroad car, train wagon', 'raining', 'rapping', + 'reversing beeps', 'ripping paper', 'roller coaster running', 'rope skipping', 'rowboat, canoe, kayak rowing', + 'running electric fan', 'sailing', 'scuba diving', 'sea lion barking', 'sea waves', 'sharpen knife', + 'sheep bleating', 'shot football', 'singing bowl', 'singing choir', 'skateboarding', 'skidding', 'skiing', + 'sliding door', 'sloshing water', 'slot machine', 'smoke detector beeping', 'snake hissing', 'snake rattling', + 'splashing water', 'spraying water', 'squishing water', 'stream burbling', 'strike lighter', 'striking bowling', + 'striking pool', 'subway, metro, underground', 'swimming', 'tap dancing', 'tapping guitar', + 'telephone bell ringing', 'thunder', 'toilet flushing', 'tornado roaring', 'tractor digging', 'train horning', + 'train wheels squealing', 'train whistling', 'turkey gobbling', 'typing on computer keyboard', + 'typing on typewriter', 'underwater bubbling', 'using sewing machines', 'vacuum cleaner cleaning floors', + 'vehicle horn, car horn, honking', 'volcano explosion', 'warbler chirping', 'waterfall burbling', + 'whale calling', 'wind chime', 'wind noise', 'wind rustling leaves', 'wood thrush calling', + 'woodpecker pecking tree', 'writing on blackboard with chalk', 'yodelling', 'zebra braying' + ) + +} diff --git a/a_cls/zeroshot_cls.py b/a_cls/zeroshot_cls.py new file mode 100644 index 0000000000000000000000000000000000000000..5272c924666261266320e087e2ed15e2ae34c614 --- /dev/null +++ b/a_cls/zeroshot_cls.py @@ -0,0 +1,46 @@ + +import json +import logging +import os +from training.distributed import is_master +from .zero_shot import zero_shot_eval + +try: + import wandb +except ImportError: + wandb = None + + + +def evaluate_a_cls(model, data, epoch, args, tb_writer=None): + metrics = {} + if not is_master(args): + return metrics + model.eval() + + zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + metrics.update(zero_shot_metrics) + + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) + ) + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/a_cls/{args.val_a_cls_data[0].lower()}/{name}", val, epoch) + args.a_cls_output_dir = os.path.join(args.log_base_path, f'a_cls/{args.val_a_cls_data[0].lower()}') + os.makedirs(args.a_cls_output_dir, exist_ok=True) + with open(os.path.join(args.a_cls_output_dir, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, 'Please install wandb.' + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, 'epoch': epoch}) + + return metrics diff --git a/al_ret/data_dataloaders.py b/al_ret/data_dataloaders.py new file mode 100644 index 0000000000000000000000000000000000000000..4f31aa0686fcd6b5a678f4d8a76c64b662001161 --- /dev/null +++ b/al_ret/data_dataloaders.py @@ -0,0 +1,28 @@ +import argparse +import torch +from torch.utils.data import DataLoader + +from data.build_datasets import get_data +from data.process_audio import get_audio_transform +from .dataloader_msrvtt_retrieval import MSRVTT_DataLoader + +def dataloader_msrvtt_test(args, tokenizer, subset="test"): + msrvtt_testset = MSRVTT_DataLoader( + csv_path=args.val_csv, + features_path=args.features_path, + max_words=args.max_words, + tokenizer=tokenizer, + transform=get_audio_transform(args) + ) + dataloader_msrvtt = DataLoader( + msrvtt_testset, + batch_size=args.batch_size_val, + num_workers=args.num_thread_reader, + shuffle=False, + drop_last=False, + ) + return dataloader_msrvtt, len(msrvtt_testset) + + +DATALOADER_DICT = {} +DATALOADER_DICT["msrvtt"] = {"val":dataloader_msrvtt_test, "test":None} diff --git a/al_ret/dataloader_msrvtt_retrieval.py b/al_ret/dataloader_msrvtt_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..903f6aea602eaec8d57d2a77293478c2bf7cb1db --- /dev/null +++ b/al_ret/dataloader_msrvtt_retrieval.py @@ -0,0 +1,114 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals +from __future__ import print_function + +import os + +import torchaudio +from torch.utils.data import Dataset +import numpy as np +import pandas as pd +from collections import defaultdict +import json +import random + +from torchvision.io import read_video + + +class MSRVTT_DataLoader(Dataset): + """MSRVTT dataset loader.""" + def __init__( + self, + csv_path, + features_path, + tokenizer, + transform=77, + max_words=30, + ): + self.data = pd.read_csv(csv_path) + self.features_path = features_path + self.max_words = max_words + self.tokenizer = tokenizer + + # self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) + self.transform = transform + self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", + "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} + + + + def __len__(self): + return len(self.data) + + def _get_text(self, video_id, sentence): + choice_video_ids = [video_id] + n_caption = len(choice_video_ids) + + k = n_caption + pairs_text = np.zeros((k, self.max_words), dtype=np.long) + pairs_mask = np.zeros((k, self.max_words), dtype=np.long) + pairs_segment = np.zeros((k, self.max_words), dtype=np.long) + + for i, video_id in enumerate(choice_video_ids): + # words = self.tokenizer.tokenize(sentence) + # + # words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words + # total_length_with_CLS = self.max_words - 1 + # if len(words) > total_length_with_CLS: + # words = words[:total_length_with_CLS] + # words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] + # + # input_ids = self.tokenizer.convert_tokens_to_ids(words) + # input_mask = [1] * len(input_ids) + # segment_ids = [0] * len(input_ids) + + + output = self.tokenizer(sentence) + + input_ids = output[0].squeeze() + input_mask = output[1].squeeze() + segment_ids = [0] * len(input_ids) + + + while len(input_ids) < self.max_words: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == self.max_words + assert len(input_mask) == self.max_words + assert len(segment_ids) == self.max_words + + pairs_text[i] = np.array(input_ids) + pairs_mask[i] = np.array(input_mask) + pairs_segment[i] = np.array(segment_ids) + + return pairs_text, pairs_mask, pairs_segment, choice_video_ids + + def _get_rawvideo(self, choice_video_ids): + # Pair x L x T x 3 x H x W + audio = np.zeros((len(choice_video_ids), 3, + self.transform.num_mel_bins, self.transform.target_length), dtype=np.float) + assert len(choice_video_ids) == 1 + for i, video_id in enumerate(choice_video_ids): + # Individual for YoucokII dataset, due to it video format + video_path = os.path.join(self.features_path, "{}.mp4".format(video_id)) + if os.path.exists(video_path) is False: + video_path = video_path.replace(".mp4", ".webm") + + # raw_video_data = self.rawVideoExtractor.get_video_data(video_path) + # _, raw_audio_data, info = read_video(video_path, pts_unit='sec') + # audio_data = self.transform((raw_audio_data, info['audio_fps'])) + + audio_data = torchaudio.load(video_path.replace('mp4', 'wav')) + audio_data = self.transform(audio_data) + # audio[i] = audio_data + return audio_data + + def __getitem__(self, idx): + video_id = self.data['video_id'].values[idx] + sentence = self.data['sentence'].values[idx] + + pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, sentence) + audio_data = self._get_rawvideo(choice_video_ids) + return audio_data, pairs_text, pairs_mask diff --git a/al_ret/datasets.py b/al_ret/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..304e7dd352b8269874e261c6c140619fc4ac4b95 --- /dev/null +++ b/al_ret/datasets.py @@ -0,0 +1,137 @@ +import logging +import os.path +import random + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset +from data.build_datasets import DataInfo +from open_clip import get_input_dtype, get_tokenizer +from open_clip.factory import HF_HUB_PREFIX +from data.process_audio import get_audio_transform, torchaudio_loader + +class Audiocaps_dataset(Dataset): + def __init__(self, data_path, transform, loader, tokenizer): + super(Audiocaps_dataset, self).__init__() + self.audio_root = data_path + raw_meta = pd.read_csv(f'{self.audio_root}/audiocaps_test.tsv', delimiter='\t').values + audio_ids = list(set(raw_meta[:, 1].tolist())) + captions = {} + for i in raw_meta: + if captions.get(i[1], None) is None: + captions[i[1]] = [i[2]] + else: + captions[i[1]] = captions[i[1]] + [i[2]] + # captions = {i[:1][0]: i[1:].tolist() for i in raw_meta} + + + self.sample_len = 0 + self.sentences_dict = {} + self.cut_off_points = [] + for audio_id in audio_ids: + assert audio_id in captions + for cap in captions[audio_id]: + cap_txt = cap + self.sentences_dict[len(self.sentences_dict)] = (audio_id[10:], cap_txt) + self.cut_off_points.append(len(self.sentences_dict)) + + self.multi_sentence_per_audio = True # !!! important tag for eval + if self.multi_sentence_per_audio: + # if self.subset == "val" or self.subset == "test": + self.sentence_num = len(self.sentences_dict) + self.audio_num = len(audio_ids) + assert len(self.cut_off_points) == self.audio_num + print("Sentence number: {}".format(self.sentence_num)) + print("Video number: {}".format(self.audio_num)) + + self.sample_len = len(self.sentences_dict) + + self.transform = transform + self.torchaudio_loader = loader + self.tokenizer = tokenizer + + def __len__(self): + return self.sample_len + + def __getitem__(self, idx): + audiocap_id, caption = self.sentences_dict[idx] + + audio_path = os.path.join(self.audio_root, audiocap_id) + audio = self.torchaudio_loader(audio_path) + audio_data = self.transform(audio) + + input_ids, attention_mask = self.tokenizer(caption) + return audio_data, input_ids.squeeze(), attention_mask.squeeze() + + +class Clotho_dataset(Dataset): + def __init__(self, data_path, transform, loader, tokenizer): + super(Clotho_dataset, self).__init__() + self.audio_root = data_path + raw_meta = pd.read_csv(f'{self.audio_root}/CLOTHO_retrieval_dataset/clotho_captions_evaluation.csv').values + audio_ids = raw_meta[:, 0].tolist() + captions = {i[:1][0]: i[1:].tolist() for i in raw_meta} + # self.meta = pd.DataFrame(np.vstack([np.vstack([raw_meta[:, 0], raw_meta[:, i]]).T for i in range(1, 6)]), + # columns=['uniq_id', 'text']) + + self.sample_len = 0 + self.sentences_dict = {} + self.cut_off_points = [] + for audio_id in audio_ids: + assert audio_id in captions + for cap in captions[audio_id]: + cap_txt = cap + self.sentences_dict[len(self.sentences_dict)] = (audio_id, cap_txt) + self.cut_off_points.append(len(self.sentences_dict)) + + self.multi_sentence_per_audio = True # !!! important tag for eval + if self.multi_sentence_per_audio: + # if self.subset == "val" or self.subset == "test": + self.sentence_num = len(self.sentences_dict) + self.audio_num = len(audio_ids) + assert len(self.cut_off_points) == self.audio_num + print("Sentence number: {}".format(self.sentence_num)) + print("Video number: {}".format(self.audio_num)) + + self.sample_len = len(self.sentences_dict) + + self.transform = transform + self.torchaudio_loader = loader + self.tokenizer = tokenizer + + def __len__(self): + return self.sample_len + + def __getitem__(self, idx): + audiocap_id, caption = self.sentences_dict[idx] + # audiocap_id = self.meta['uniq_id'][idx] + audio_path = os.path.join(self.audio_root, f'evaluation/{audiocap_id}') + audio = self.torchaudio_loader(audio_path) + audio_data = self.transform(audio) + + # caption = self.meta['text'][idx] + input_ids, attention_mask = self.tokenizer(caption) + return audio_data, input_ids.squeeze(), attention_mask.squeeze() + +def get_audio_dataset(args): + data_path = args.audio_data_path + transform = get_audio_transform(args) + tokenizer = get_tokenizer(HF_HUB_PREFIX+args.model, cache_dir=args.cache_dir) + + if args.val_al_ret_data.lower() == 'audiocaps': + dataset = Audiocaps_dataset(data_path, transform=transform, loader=torchaudio_loader, tokenizer=tokenizer) + elif args.val_al_ret_data.lower() == 'clotho': + dataset = Clotho_dataset(data_path, transform=transform, loader=torchaudio_loader, tokenizer=tokenizer) + else: + raise ValueError(f'unsupport dataset {args.val_al_ret_data}') + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + shuffle=False, + drop_last=False, + ) + + return dataloader diff --git a/al_ret/metrics.py b/al_ret/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..708f8c9aec43a3b4b768f6a22739a268d8c38a16 --- /dev/null +++ b/al_ret/metrics.py @@ -0,0 +1,70 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals +from __future__ import print_function + +import numpy as np +import torch + +def compute_metrics(x): + sx = np.sort(-x, axis=1) + d = np.diag(-x) + d = d[:, np.newaxis] + ind = sx - d + ind = np.where(ind == 0) + ind = ind[1] + metrics = {} + metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) + metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) + metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) + metrics['MR'] = np.median(ind) + 1 + metrics["MedianR"] = metrics['MR'] + metrics["MeanR"] = np.mean(ind) + 1 + # metrics["cols"] = [int(i) for i in list(ind)] + return metrics + +def print_computed_metrics(metrics): + r1 = metrics['R1'] + r5 = metrics['R5'] + r10 = metrics['R10'] + mr = metrics['MR'] + print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr)) + +# below two functions directly come from: https://github.com/Deferf/Experiments +def tensor_text_to_video_metrics(sim_tensor, top_k = [1,5,10]): + if not torch.is_tensor(sim_tensor): + sim_tensor = torch.tensor(sim_tensor) + + # Permute sim_tensor so it represents a sequence of text-video similarity matrices. + # Then obtain the double argsort to position the rank on the diagonal + stacked_sim_matrices = sim_tensor.permute(1, 0, 2) + first_argsort = torch.argsort(stacked_sim_matrices, dim = -1, descending= True) + second_argsort = torch.argsort(first_argsort, dim = -1, descending= False) + + # Extracts ranks i.e diagonals + ranks = torch.flatten(torch.diagonal(second_argsort, dim1 = 1, dim2 = 2)) + + # Now we need to extract valid ranks, as some belong to inf padding values + permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1 = 0, dim2 = 2)) + mask = ~ torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data)) + valid_ranks = ranks[mask] + # A quick dimension check validates our results, there may be other correctness tests pending + # Such as dot product localization, but that is for other time. + #assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict]) + if not torch.is_tensor(valid_ranks): + valid_ranks = torch.tensor(valid_ranks) + results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k} + results["MedianR"] = float(torch.median(valid_ranks + 1)) + results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1)) + results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1)) + results['MR'] = results["MedianR"] + return results + +def tensor_video_to_text_sim(sim_tensor): + if not torch.is_tensor(sim_tensor): + sim_tensor = torch.tensor(sim_tensor) + # Code to avoid nans + sim_tensor[sim_tensor != sim_tensor] = float('-inf') + # Forms a similarity matrix for use with rank at k + values, _ = torch.max(sim_tensor, dim=1, keepdim=True) + return torch.squeeze(values).T diff --git a/al_ret/precision.py b/al_ret/precision.py new file mode 100644 index 0000000000000000000000000000000000000000..a63b92256518d13afd57261df1568e26b1622201 --- /dev/null +++ b/al_ret/precision.py @@ -0,0 +1,12 @@ +import torch +from contextlib import suppress + + +def get_autocast(precision): + if precision == 'amp': + return torch.cuda.amp.autocast + elif precision == 'amp_bfloat16' or precision == 'amp_bf16': + # amp_bfloat16 is more stable than amp float16 for clip training + return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return suppress diff --git a/al_ret/retrieval.py b/al_ret/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..8fca37f2d9c6c1f74cf9102cde01f01d8448fa61 --- /dev/null +++ b/al_ret/retrieval.py @@ -0,0 +1,266 @@ + +import json +import logging +import os +import numpy as np +import torch + +from training.distributed import is_master +from .zero_shot import zero_shot_eval +from .util import parallel_apply +from .metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim +from torch.nn import functional as F +try: + import wandb +except ImportError: + wandb = None + + +# +# def evaluate_al_ret(model, data, epoch, args, tb_writer=None): +# metrics = {} +# if not is_master(args): +# return metrics +# model.eval() +# +# zero_shot_metrics = zero_shot_eval(model, data, epoch, args) +# metrics.update(zero_shot_metrics) +# +# if not metrics: +# return metrics +# +# logging.info( +# f"Eval Epoch: {epoch} " +# + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) +# ) +# +# if args.save_logs: +# for name, val in metrics.items(): +# if tb_writer is not None: +# tb_writer.add_scalar(f"val/al_ret/{name}", val, epoch) +# args.al_ret_output_dir = os.path.join(args.log_base_path, 'al_ret') +# os.makedirs(args.al_ret_output_dir, exist_ok=True) +# with open(os.path.join(args.al_ret_output_dir, "results.jsonl"), "a+") as f: +# f.write(json.dumps(metrics)) +# f.write("\n") +# +# if args.wandb: +# assert wandb is not None, 'Please install wandb.' +# for name, val in metrics.items(): +# wandb.log({f"val/{name}": val, 'epoch': epoch}) +# +# return metrics + + + +def _run_on_single_gpu(model, + # batch_list_t, batch_list_v, + batch_sequence_output_list, batch_visual_output_list): + sim_matrix = [] + for idx1 in range(len(batch_sequence_output_list)): + # input_mask, segment_ids, *_tmp = b1 + sequence_output = batch_sequence_output_list[idx1] + each_row = [] + for idx2 in range(len(batch_visual_output_list)): + # video_mask, *_tmp = b2 + visual_output = batch_visual_output_list[idx2] + # b1b2_logits, *_tmp = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask, + # loose_type=model.loose_type) + # logging.info(f"{model.logit_scale.device}, {visual_output.device}, {sequence_output.device}") + b1b2_logits = model.logit_scale * sequence_output @ visual_output.T + # print(model.logit_scale.device, visual_output.device, sequence_output.device) + # logging.info(f"{b1b2_logits.shape}, {b1b2_logits.device}") + b1b2_logits = b1b2_logits.cpu().detach().numpy() + each_row.append(b1b2_logits) + each_row = np.concatenate(tuple(each_row), axis=-1) + sim_matrix.append(each_row) + return sim_matrix + +def evaluate_al_ret(model, data, epoch, args, tb_writer=None): + if is_master(args) and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)): + # print(data) + val_al_ret_data = list(data.keys()) + # print(val_vl_ret_data) + assert len(val_al_ret_data) == 1 + val_al_ret_data = val_al_ret_data[0] + test_dataloader = data[val_al_ret_data] + # print(len(test_dataloader)) + # print(len(test_dataloader)) + # print(len(test_dataloader)) + # print(len(test_dataloader)) + device = model.device + n_gpu = torch.cuda.device_count() + logging.info(f"\nEval Epoch: {epoch}, eval Audio-Text Retrieval under {val_al_ret_data.upper()} test data") + if hasattr(model, 'module'): + model = model.module.to(device) + else: + model = model.to(device) + # ################################################################# + ## below variables are used to multi-sentences retrieval + # multi_sentence_: important tag for eval + # cut_off_points: used to tag the label when calculate the metric + # sentence_num: used to cut the sentence representation + # video_num: used to cut the video representation + # ################################################################# + multi_sentence_ = False + cut_off_points_, sentence_num_, video_num_ = [], -1, -1 + if hasattr(test_dataloader.dataset, 'multi_sentence_per_audio') and test_dataloader.dataset.multi_sentence_per_audio: + # if False: + multi_sentence_ = True + cut_off_points_ = test_dataloader.dataset.cut_off_points + sentence_num_ = test_dataloader.dataset.sentence_num + video_num_ = test_dataloader.dataset.audio_num + cut_off_points_ = [itm - 1 for itm in cut_off_points_] + + if multi_sentence_: + print("Eval under the multi-sentence per audio clip setting.") + print("sentence num: {}, video num: {}".format(sentence_num_, video_num_)) + logging.info("Eval under the multi-sentence per audio clip setting.") + logging.info("sentence num: {}, video num: {}".format(sentence_num_, video_num_)) + + model.eval() + with torch.no_grad(): + # batch_list_t = [] + # batch_list_v = [] + batch_sequence_output_list, batch_visual_output_list = [], [] + total_video_num = 0 + + # ---------------------------- + # 1. cache the features + # ---------------------------- + for bid, batch in enumerate(test_dataloader): + # batch = tuple(t.to(device) for t in batch) + video, input_ids, attention_mask = batch + # print(input_ids.shape, video.shape, video.dtype) + input_ids = input_ids.squeeze().to(device) + attention_mask = attention_mask.squeeze().to(device) + # video = video.squeeze().permute(0, 2, 1, 3, 4).float().to(device) + video = video.float().to(device) + + + + # print(input_ids.shape, video.shape, video.dtype) + # print(input_ids.shape, video.shape) + if multi_sentence_: + # multi-sentences retrieval means: one clip has two or more descriptions. + b, *_t = video.shape + sequence_output = model.encode_text(input_ids, attention_mask) + # logging.info(f'multi: {sequence_output.shape}') + # sequence_output = model.get_sequence_output(input_ids, segment_ids, input_mask) + batch_sequence_output_list.append(sequence_output) + # batch_list_t.append((input_mask, segment_ids,)) + # 0 16 + s_, e_ = total_video_num, total_video_num + b + filter_inds = [itm - s_ for itm in cut_off_points_ if itm >= s_ and itm < e_] # cut_off_points_ [0 4 9 14] + + if len(filter_inds) > 0: + # video, video_mask = video[filter_inds, ...], video_mask[filter_inds, ...] + # print('before', video.shape) + video = video[filter_inds, ...] + # print('after', video.shape) + # visual_output = model.get_visual_output(video, video_mask) + visual_output = model.encode_image(video) + batch_visual_output_list.append(visual_output) + # batch_list_v.append((video_mask,)) + total_video_num += b + else: + sequence_output = model.encode_text(input_ids, attention_mask) + visual_output = model.encode_image(video) + # sequence_output, visual_output = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask) + + batch_sequence_output_list.append(sequence_output) + # batch_list_t.append((input_mask, segment_ids,)) + + batch_visual_output_list.append(visual_output) + # batch_list_v.append((video_mask,)) + + print(f"Process {val_al_ret_data.upper()}: {bid}/{len(test_dataloader)}\r", end='') + # ---------------------------------- + # 2. calculate the similarity + # ---------------------------------- + n_gpu = torch.cuda.device_count() + if n_gpu > 1: + # print('n_gpu > 1') + device_ids = list(range(n_gpu)) + # print('device_ids', device_ids) + batch_t_output_splits = [] + batch_v_output_splits = [] + bacth_len = len(batch_sequence_output_list) + # print(bacth_len) + split_len = (bacth_len + n_gpu - 1) // n_gpu + for dev_id in device_ids: + s_, e_ = dev_id * split_len, (dev_id + 1) * split_len + if dev_id == 0: + + batch_t_output_splits.append(batch_sequence_output_list[s_:e_]) + batch_v_output_splits.append(batch_visual_output_list) + # print(len(batch_sequence_output_list[s_:e_]), len(batch_visual_output_list)) + else: + devc = torch.device('cuda:{}'.format(str(dev_id))) + + devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]] + batch_t_output_splits.append(devc_batch_list) + devc_batch_list = [b.to(devc) for b in batch_visual_output_list] + batch_v_output_splits.append(devc_batch_list) + # print(len(devc_batch_list), len(devc_batch_list)) + parameters_tuple_list = [( + batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids] + parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids) + sim_matrix = [] + for idx in range(len(parallel_outputs)): + sim_matrix += parallel_outputs[idx] + sim_matrix = np.concatenate(tuple(sim_matrix), axis=0) + else: + sim_matrix = _run_on_single_gpu(model, + # batch_list_t, batch_list_v, + batch_sequence_output_list, batch_visual_output_list) + sim_matrix = np.concatenate(tuple(sim_matrix), axis=0) + ##################################################################### + if multi_sentence_: + + logging.info(f"{val_al_ret_data.upper()} before reshape, sim matrix size: {sim_matrix.shape}") + cut_off_points2len_ = [itm + 1 for itm in cut_off_points_] + max_length = max([e_-s_ for s_, e_ in zip([0]+cut_off_points2len_[:-1], cut_off_points2len_)]) + sim_matrix_new = [] + for s_, e_ in zip([0] + cut_off_points2len_[:-1], cut_off_points2len_): + sim_matrix_new.append(np.concatenate((sim_matrix[s_:e_], + np.full((max_length-e_+s_, sim_matrix.shape[1]), -np.inf)), axis=0)) + sim_matrix = np.stack(tuple(sim_matrix_new), axis=0) + logging.info(f"{val_al_ret_data.upper()} after reshape, sim matrix size: {sim_matrix.shape}") + + tv_metrics = tensor_text_to_video_metrics(sim_matrix) + # vt_metrics = compute_metrics(tensor_video_to_text_sim(sim_matrix)) + else: + logging.info(f"{val_al_ret_data.upper()} sim matrix size: {sim_matrix.shape[0]}, {sim_matrix.shape[1]}") + t2v_sim_matrix = torch.from_numpy(sim_matrix).cuda() + # t2v_sim_matrix = t2v_sim_matrix * F.softmax(t2v_sim_matrix*10, dim=0) * len(t2v_sim_matrix) + tv_metrics = compute_metrics(t2v_sim_matrix.cpu().numpy()) + + + # vt_metrics = compute_metrics(t2v_sim_matrix.T.cpu().numpy()) + + logging.info('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0]))) + + logging.info(f"{val_al_ret_data.upper()} Text-to-Audio:") + logging.info('\t>>> R@1: {:.1f} - R@5: {:.1f} - R@10: {:.1f} - Median R: {:.1f} - Mean R: {:.1f}'. + format(tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR'])) + # logging.info(f"{val_al_ret_data.upper()} Text-to-Audio:") + # logging.info('\t>>> V2T$R@1: {:.1f} - V2T$R@5: {:.1f} - V2T$R@10: {:.1f} - V2T$Median R: {:.1f} - V2T$Mean R: {:.1f}'. + # format(vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR'])) + + + if args.save_logs: + for name, val in tv_metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/al_ret/{val_al_ret_data}/t2a/{name}", val, epoch) + # for name, val in vt_metrics.items(): + # if tb_writer is not None: + # tb_writer.add_scalar(f"val/al_ret/{val_al_ret_data}/v2t/{name}", val, epoch) + + args.al_ret_output_dir = os.path.join(args.log_base_path, f'al_ret/{val_al_ret_data}') + os.makedirs(args.al_ret_output_dir, exist_ok=True) + with open(os.path.join(args.al_ret_output_dir, "results.jsonl"), "a+") as f: + f.write(json.dumps({'t2a': tv_metrics})) + f.write("\n") + # f.write(json.dumps({'v2t': vt_metrics})) + # f.write("\n") \ No newline at end of file diff --git a/al_ret/util.py b/al_ret/util.py new file mode 100644 index 0000000000000000000000000000000000000000..6b11cd4ea86d93304a882bebda2f5128bec7eb4d --- /dev/null +++ b/al_ret/util.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +import threading +from torch._utils import ExceptionWrapper +import logging + +def get_a_var(obj): + if isinstance(obj, torch.Tensor): + return obj + + if isinstance(obj, list) or isinstance(obj, tuple): + for result in map(get_a_var, obj): + if isinstance(result, torch.Tensor): + return result + if isinstance(obj, dict): + for result in map(get_a_var, obj.items()): + if isinstance(result, torch.Tensor): + return result + return None + +def parallel_apply(fct, model, inputs, device_ids): + modules = nn.parallel.replicate(model, device_ids) + assert len(modules) == len(inputs) + lock = threading.Lock() + results = {} + grad_enabled = torch.is_grad_enabled() + + def _worker(i, module, input): + torch.set_grad_enabled(grad_enabled) + device = get_a_var(input).get_device() + try: + with torch.cuda.device(device): + # this also avoids accidental slicing of `input` if it is a Tensor + if not isinstance(input, (list, tuple)): + input = (input,) + output = fct(module, *input) + with lock: + results[i] = output + except Exception: + with lock: + results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device)) + + if len(modules) > 1: + threads = [threading.Thread(target=_worker, args=(i, module, input)) + for i, (module, input) in enumerate(zip(modules, inputs))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + else: + _worker(0, modules[0], inputs[0]) + + outputs = [] + for i in range(len(inputs)): + output = results[i] + if isinstance(output, ExceptionWrapper): + output.reraise() + outputs.append(output) + return outputs + +def get_logger(filename=None): + logger = logging.getLogger('logger') + logger.setLevel(logging.DEBUG) + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO) + if filename is not None: + handler = logging.FileHandler(filename) + handler.setLevel(logging.DEBUG) + handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) + logging.getLogger().addHandler(handler) + return logger \ No newline at end of file diff --git a/al_ret/zero_shot.py b/al_ret/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..16f32e6c31b7affa3910dacbac24fa4b54ae9fc1 --- /dev/null +++ b/al_ret/zero_shot.py @@ -0,0 +1,91 @@ +import logging + +import numpy as np +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from open_clip import get_input_dtype, get_tokenizer +from open_clip.factory import HF_HUB_PREFIX +from .precision import get_autocast + +def compute_metrics(x): + sx = np.sort(-x, axis=1) + d = np.diag(-x) + d = d[:, np.newaxis] + ind = sx - d + ind = np.where(ind == 0) + ind = ind[1] + metrics = {} + metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) + metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) + metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) + metrics['MR'] = np.median(ind) + 1 + metrics["MedianR"] = metrics['MR'] + metrics["MeanR"] = np.mean(ind) + 1 + # metrics["cols"] = [int(i) for i in list(ind)] + return metrics + + +def _run_on_single_gpu(model, batch_sequence_output_list, batch_visual_output_list): + sim_matrix = [] + logit_scale = model.logit_scale.exp() + for idx1, sequence_output in enumerate(batch_sequence_output_list): + each_row = [] + for idx2, visual_output in enumerate(batch_visual_output_list): + b1b2_logits = logit_scale * torch.matmul(sequence_output, visual_output.t()) + b1b2_logits = b1b2_logits.cpu().detach().numpy() + each_row.append(b1b2_logits) + each_row = np.concatenate(tuple(each_row), axis=-1) + sim_matrix.append(each_row) + return sim_matrix + +def run(model, dataloader, args): + autocast = get_autocast(args.precision) + input_dtype = get_input_dtype(args.precision) + + with torch.no_grad(): + sequence_output_list, visual_output_list = [], [] + for images, input_ids, attention_mask in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(device=args.device, dtype=input_dtype) + images = images.unsqueeze(2) + input_ids = input_ids.squeeze().to(args.device) + attention_mask = attention_mask.squeeze().to(args.device) + + with autocast(): + # predict + sequence_output = model.encode_text(input_ids, attention_mask) + visual_output = model.encode_image(images) + sequence_output_list.append(sequence_output) + visual_output_list.append(visual_output) + sim_matrix = _run_on_single_gpu(model, sequence_output_list, visual_output_list) + sim_matrix = np.concatenate(tuple(sim_matrix), axis=0) + return sim_matrix + + +def zero_shot_eval(model, data, epoch, args): + temp_val_al_ret_data = args.val_al_ret_data + args.val_al_ret_data = list(data.keys()) + assert len(args.val_al_ret_data) == 1 + args.val_al_ret_data = args.val_al_ret_data[0] + + if args.val_al_ret_data not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + if args.distributed and not args.horovod: + model = model.module + + logging.info(f'Starting zero-shot {args.val_al_ret_data.upper()}.') + + results = {} + if args.val_al_ret_data in data: + logit_matrix = run(model, data[args.val_al_ret_data].dataloader, args) + results = compute_metrics(logit_matrix) + + logging.info(f'Finished zero-shot {args.val_al_ret_data.upper()}.') + + args.val_al_ret_data = temp_val_al_ret_data + return results diff --git a/assets/audio/0.wav b/assets/audio/0.wav new file mode 100644 index 0000000000000000000000000000000000000000..e8073ee997d577abf4b3ae328e0a067f92e65ed7 --- /dev/null +++ b/assets/audio/0.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38aff33c1d6e68dfa0bd310d1e4cff10df4ac3642b3cc96637fab2a0e74b64a9 +size 327788 diff --git a/assets/audio/1.wav b/assets/audio/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..89b312fa3a16a6044d01b13e22e43eff638a89e3 --- /dev/null +++ b/assets/audio/1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:25400620506fbc099ee78fe4d31379a218a264ddcfcfe658e4ba9c2255fc6c01 +size 327788 diff --git a/assets/demo.png b/assets/demo.png new file mode 100644 index 0000000000000000000000000000000000000000..bc3430dd0bf6bae5a583496ea3fe7bd4498d73f0 --- /dev/null +++ b/assets/demo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34d7015339c050253fd4044c26bb1b82b423e04106ae84e193bbc87640cbf2de +size 364132 diff --git a/assets/depth/0.png b/assets/depth/0.png new file mode 100644 index 0000000000000000000000000000000000000000..f0b27506f0913d1cf9aab9b390d0c67063f150e1 --- /dev/null +++ b/assets/depth/0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0bb5fa3ffca3067c69ec6b1dfc600798491a1005f9dd2bdaf0c98c5b3a1d2ac +size 232828 diff --git a/assets/depth/1.png b/assets/depth/1.png new file mode 100644 index 0000000000000000000000000000000000000000..2c9ff050dcef9bedfdc65b0e6197dde7ac7ab5f0 --- /dev/null +++ b/assets/depth/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f642578b11c348dc12ccb0b3d19c146986e6abc17f8cfc02f1cf8e325cdaeaf0 +size 234367 diff --git a/assets/emergency.jpg b/assets/emergency.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2ccacfd27a6d7c0102ca67c6ba90e9ba575dfb12 --- /dev/null +++ b/assets/emergency.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db64fd16c971bd704deb7290f347bc784772e52f8718d78724d3805081a57ef7 +size 204125 diff --git a/assets/iclr_dataset_sample.jpg b/assets/iclr_dataset_sample.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4c9bfb23e332497341c41b0740562eab502883f8 --- /dev/null +++ b/assets/iclr_dataset_sample.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81dee815642f74a217e20138a60f9fa6bc76c2a5f2ae5faed18741ef755f6a6e +size 169016 diff --git a/assets/image/0.jpg b/assets/image/0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..553f0c2f8a039eb42aafeac4ad672a4b08eb5088 Binary files /dev/null and b/assets/image/0.jpg differ diff --git a/assets/image/1.jpg b/assets/image/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..41d2d17ff153661dbec7b63b2e5c6bc68d1a4217 Binary files /dev/null and b/assets/image/1.jpg differ diff --git a/assets/languagebind.jpg b/assets/languagebind.jpg new file mode 100644 index 0000000000000000000000000000000000000000..405a1d8543269d2fdc535fe0beb5c959ba234ee8 --- /dev/null +++ b/assets/languagebind.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df5faf91d750c28ce16a2ac919b2e277320274c8c1c3636aa572316adcb9c5c1 +size 272797 diff --git a/assets/languagebind_frame.jpg b/assets/languagebind_frame.jpg new file mode 100644 index 0000000000000000000000000000000000000000..01c7b523c577f07ef7cd457481125dd4c4ab0914 --- /dev/null +++ b/assets/languagebind_frame.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a400701a13ffdc459a5edc933aeb5290aa7114034ef8594f435b7906f15f767 +size 1356820 diff --git a/assets/languagebind_result.jpg b/assets/languagebind_result.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d9b9c419f493b227936e47ece0f9e79f68baaec8 --- /dev/null +++ b/assets/languagebind_result.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dac8188d8911a77ab9ecaeeb45303d39422af073c55b7a6785dff664ed4ce544 +size 441095 diff --git a/assets/languge_result.jpg b/assets/languge_result.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d9b9c419f493b227936e47ece0f9e79f68baaec8 --- /dev/null +++ b/assets/languge_result.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dac8188d8911a77ab9ecaeeb45303d39422af073c55b7a6785dff664ed4ce544 +size 441095 diff --git a/assets/logo.jpg b/assets/logo.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4cfa5155b774c3035ddd946a3f256f1caaa47c7a --- /dev/null +++ b/assets/logo.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8cdf04f5629c0ffbdcb6dd0fd3ef9df91665361c8da7fa55aaf050ad33408c4c +size 914662 diff --git a/assets/logo_languagebind.png b/assets/logo_languagebind.png new file mode 100644 index 0000000000000000000000000000000000000000..13756e0bd0951b7e853abc99c1aae456c66748cc --- /dev/null +++ b/assets/logo_languagebind.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4b53c886ec8a771db8de8812f681aeb8e80a2457fb174aa12753ebf3a835507 +size 907672 diff --git a/assets/res1.jpg b/assets/res1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ab1c024afd6a1d8ae430e755be0b9b62fb955e56 Binary files /dev/null and b/assets/res1.jpg differ diff --git a/assets/res2.jpg b/assets/res2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e36ca721912b0dd109224ffec49c981d1ec51d48 Binary files /dev/null and b/assets/res2.jpg differ diff --git a/assets/result1.jpg b/assets/result1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cbe52cb8d2fcbc62820ce2e59d814e8ea248f774 --- /dev/null +++ b/assets/result1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:322aef993d3c5cec718cda144e9b7eb55751dcd5704d526dd1170a0cb04ff697 +size 142425 diff --git a/assets/sota.jpg b/assets/sota.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6c716570241e80eef1b06b00489bfc489af27f45 --- /dev/null +++ b/assets/sota.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:166389a5f6e92f21bbb5cc7b57d50df05b5fbee3fc7da6c5bb9dbb5d9a90666f +size 198994 diff --git a/assets/thermal/0.jpg b/assets/thermal/0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..497f15d31daa55f86ae83a23a66460b20da4b251 Binary files /dev/null and b/assets/thermal/0.jpg differ diff --git a/assets/thermal/1.jpg b/assets/thermal/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..eaca85c14b8fa348f1db7494af87efc0c7091465 Binary files /dev/null and b/assets/thermal/1.jpg differ diff --git a/assets/video/0.mp4 b/assets/video/0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..82a44e52de082c8eabff23db954675761b9f1da3 --- /dev/null +++ b/assets/video/0.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d92bdf4ad672f6bc82a72c886c3c8bc7e799866bbe41b184d640a6c5f21a075 +size 661405 diff --git a/assets/video/1.mp4 b/assets/video/1.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..bb9c32015fa46113698872f630446ec9cd7655e9 --- /dev/null +++ b/assets/video/1.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a6dcc0228ffcadaaac4441476f02d3109c3f005af56aeb609a0ee1f66128b80 +size 590954 diff --git a/d_cls/cp_zero_shot_metadata.py b/d_cls/cp_zero_shot_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..4bc3d68d2e97b6f1daaeabe35697d8cd04facd8a --- /dev/null +++ b/d_cls/cp_zero_shot_metadata.py @@ -0,0 +1,117 @@ +import os + +import pandas as pd + +OPENAI_IMAGENET_TEMPLATES = ( + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +) + + +# a much smaller subset of above prompts +# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +SIMPLE_IMAGENET_TEMPLATES = ( + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +) + + +IMAGENET_CLASSNAMES = ( + +) + + +CLASSNAMES = { + 'NYUV2': ( + "bathroom", "bedroom", "bookstore", "classroom", "dining room", + "home office", "kitchen", "living room", "office", "others" + ), + 'SUNRGBD': ( + "bathroom", "bedroom", "classroom", "computer room", "conference room", "corridor", "dining area", + "dining room", "discussion area", "furniture store", "home office", "kitchen", "lab", "lecture theatre", + "library", "living room", "office", "rest space", "study space" + ), +} diff --git a/d_cls/datasets.py b/d_cls/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..a750daa8c0ca2950ae8d02adb72edf9a69d5f72e --- /dev/null +++ b/d_cls/datasets.py @@ -0,0 +1,20 @@ +import cv2 +import torch + +from data.build_datasets import DataInfo +from data.process_depth import get_depth_transform, opencv_loader +from torchvision import datasets + +def get_depth_dataset(args): + data_path = args.depth_data_path + transform = get_depth_transform(args) + dataset = datasets.ImageFolder(data_path, transform=transform, loader=opencv_loader) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=None, + ) + + return DataInfo(dataloader=dataloader, sampler=None) diff --git a/d_cls/precision.py b/d_cls/precision.py new file mode 100644 index 0000000000000000000000000000000000000000..a63b92256518d13afd57261df1568e26b1622201 --- /dev/null +++ b/d_cls/precision.py @@ -0,0 +1,12 @@ +import torch +from contextlib import suppress + + +def get_autocast(precision): + if precision == 'amp': + return torch.cuda.amp.autocast + elif precision == 'amp_bfloat16' or precision == 'amp_bf16': + # amp_bfloat16 is more stable than amp float16 for clip training + return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return suppress diff --git a/d_cls/zero_shot.py b/d_cls/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..d5c005a6ad91e128c210e8ddb48310cad2bf6e8c --- /dev/null +++ b/d_cls/zero_shot.py @@ -0,0 +1,90 @@ +import logging + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from open_clip import get_input_dtype, get_tokenizer +from open_clip.factory import HF_HUB_PREFIX +from .precision import get_autocast +from .zero_shot_classifier import build_zero_shot_classifier +from .zero_shot_metadata import CLASSNAMES, OPENAI_IMAGENET_TEMPLATES + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] + + +def run(model, classifier, dataloader, args): + autocast = get_autocast(args.precision) + input_dtype = get_input_dtype(args.precision) + + with torch.no_grad(): + top1, top5, n = 0., 0., 0. + for images, target in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(device=args.device, dtype=input_dtype) + images = images.unsqueeze(2) + target = target.to(args.device) + + with autocast(): + # predict + output = model(image=images) + image_features = output['image_features'] if isinstance(output, dict) else output[0] + logits = 100. * image_features @ classifier + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = (top1 / n) + top5 = (top5 / n) + return top1, top5 + + +def zero_shot_eval(model, data, epoch, args): + temp_val_d_cls_data = args.val_d_cls_data + args.val_d_cls_data = list(data.keys()) + assert len(args.val_d_cls_data) == 1 + args.val_d_cls_data = args.val_d_cls_data[0] + + if args.val_d_cls_data not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + if args.distributed and not args.horovod: + model = model.module + + logging.info(f'Starting zero-shot {args.val_d_cls_data.upper()}.') + + logging.info('Building zero-shot classifier') + autocast = get_autocast(args.precision) + with autocast(): + tokenizer = get_tokenizer(HF_HUB_PREFIX+args.model, cache_dir=args.cache_dir) + # tokenizer = get_tokenizer("ViT-L-14") + classifier = build_zero_shot_classifier( + model, + tokenizer=tokenizer, + classnames=CLASSNAMES[args.val_d_cls_data], + templates=OPENAI_IMAGENET_TEMPLATES, + num_classes_per_batch=10, + device=args.device, + use_tqdm=True, + ) + + logging.info('Using classifier') + results = {} + if args.val_d_cls_data in data: + top1, top5 = run(model, classifier, data[args.val_d_cls_data].dataloader, args) + results[f'{args.val_d_cls_data}-zeroshot-val-top1'] = top1 + results[f'{args.val_d_cls_data}-zeroshot-val-top5'] = top5 + + logging.info(f'Finished zero-shot {args.val_d_cls_data.upper()}.') + + args.val_d_cls_data = temp_val_d_cls_data + return results diff --git a/d_cls/zero_shot_classifier.py b/d_cls/zero_shot_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a5267cea4119994e30bb4830a6744cf25bdbaf --- /dev/null +++ b/d_cls/zero_shot_classifier.py @@ -0,0 +1,111 @@ +from functools import partial +from itertools import islice +from typing import Callable, List, Optional, Sequence, Union + +import torch +import torch.nn.functional as F + + +def batched(iterable, n): + """Batch data into lists of length *n*. The last batch may be shorter. + NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl + """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch + + +def build_zero_shot_classifier( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + num_classes_per_batch: Optional[int] = 10, + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names in batches + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + num_classes_per_batch: The number of classes to batch together in each forward, all if None + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + use_format = isinstance(templates[0], str) + num_templates = len(templates) + num_classes = len(classnames) + if use_tqdm: + import tqdm + num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) + iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) + else: + iter_wrap = iter + + def _process_batch(batch_classnames): + num_batch_classes = len(batch_classnames) + texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] + input_ids, attention_mask = tokenizer(texts) + input_ids, attention_mask = input_ids.to(device), attention_mask.to(device) + class_embeddings = F.normalize(model.encode_text(input_ids, attention_mask), dim=-1) + class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) + class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) + class_embeddings = class_embeddings.T + return class_embeddings + + with torch.no_grad(): + if num_classes_per_batch: + batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] + zeroshot_weights = torch.cat(batched_embeds, dim=1) + else: + zeroshot_weights = _process_batch(classnames) + return zeroshot_weights + + +def build_zero_shot_classifier_legacy( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names 1 by 1 + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + if use_tqdm: + import tqdm + iter_wrap = tqdm.tqdm + else: + iter_wrap = iter + + use_format = isinstance(templates[0], str) + + with torch.no_grad(): + zeroshot_weights = [] + for classname in iter_wrap(classnames): + texts = [template.format(classname) if use_format else template(classname) for template in templates] + texts = tokenizer(texts).to(device) # tokenize + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) + + return zeroshot_weights + diff --git a/d_cls/zero_shot_metadata.py b/d_cls/zero_shot_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..871151ca71d53f25e21f19f67b680cb09cd778e2 --- /dev/null +++ b/d_cls/zero_shot_metadata.py @@ -0,0 +1,117 @@ +import os + +import pandas as pd + +OPENAI_IMAGENET_TEMPLATES = ( + lambda c: f'a bad depth photo of a {c}.', + lambda c: f'a depth photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a depth photo of the hard to see {c}.', + lambda c: f'a low resolution depth photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad depth photo of the {c}.', + lambda c: f'a cropped depth photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a depth photo of a hard to see {c}.', + lambda c: f'a bright depth photo of a {c}.', + lambda c: f'a depth photo of a clean {c}.', + lambda c: f'a depth photo of a dirty {c}.', + lambda c: f'a dark depth photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a depth photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a depth photo of the cool {c}.', + lambda c: f'a close-up depth photo of a {c}.', + lambda c: f'a black and white depth photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated depth photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright depth photo of the {c}.', + lambda c: f'a cropped depth photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a depth photo of the dirty {c}.', + lambda c: f'a jpeg corrupted depth photo of a {c}.', + lambda c: f'a blurry depth photo of the {c}.', + lambda c: f'a depth photo of the {c}.', + lambda c: f'a good depth photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a depth photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up depth photo of the {c}.', + lambda c: f'a depth photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution depth photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a depth photo of the clean {c}.', + lambda c: f'a depth photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a depth photo of a nice {c}.', + lambda c: f'a depth photo of a weird {c}.', + lambda c: f'a blurry depth photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated depth photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted depth photo of the {c}.', + lambda c: f'a good depth photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a depth photo of the nice {c}.', + lambda c: f'a depth photo of the small {c}.', + lambda c: f'a depth photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a depth photo of the large {c}.', + lambda c: f'a black and white depth photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark depth photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a depth photo of a cool {c}.', + lambda c: f'a depth photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +) + + +# a much smaller subset of above prompts +# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +SIMPLE_IMAGENET_TEMPLATES = ( + lambda c: f'itap of a {c}.', + lambda c: f'a bad depth photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a depth photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a depth photo of the small {c}.', +) + + +IMAGENET_CLASSNAMES = ( + +) + + +CLASSNAMES = { + 'NYUV2': ( + "bathroom", "bedroom", "bookstore", "classroom", "dining room", + "home office", "kitchen", "living room", "office", "others" + ), + 'SUNRGBD': ( + "bathroom", "bedroom", "classroom", "computer room", "conference room", "corridor", "dining area", + "dining room", "discussion area", "furniture store", "home office", "kitchen", "lab", "lecture theatre", + "library", "living room", "office", "rest space", "study space" + ), +} diff --git a/d_cls/zeroshot_cls.py b/d_cls/zeroshot_cls.py new file mode 100644 index 0000000000000000000000000000000000000000..03f4439518f298c76c751746a40fab2563361ec9 --- /dev/null +++ b/d_cls/zeroshot_cls.py @@ -0,0 +1,47 @@ + +import json +import logging +import os +from training.distributed import is_master +from .zero_shot import zero_shot_eval + +try: + import wandb +except ImportError: + wandb = None + + + +def evaluate_d_cls(model, data, epoch, args, tb_writer=None): + metrics = {} + if not is_master(args): + return metrics + model.eval() + + zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + metrics.update(zero_shot_metrics) + + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) + ) + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/d_cls/{args.val_d_cls_data[0].lower()}/{name}", val, epoch) + args.d_cls_output_dir = os.path.join(args.log_base_path, f'd_cls/{args.val_d_cls_data[0].lower()}') + os.makedirs(args.d_cls_output_dir, exist_ok=True) + with open(os.path.join(args.d_cls_output_dir, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, 'Please install wandb.' + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, 'epoch': epoch}) + + return metrics diff --git a/data/base_datasets.py b/data/base_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..b94297502b6c576e325742e1228355dd7218f3b7 --- /dev/null +++ b/data/base_datasets.py @@ -0,0 +1,215 @@ +import contextlib +import io +import json +import logging +import os.path +import random +import re +import time + +import pandas as pd + +from a_cls.dataloader import make_midname_dict +from open_clip import get_tokenizer +from open_clip.factory import HF_HUB_PREFIX +from .process_video import load_and_transform_video, get_video_transform +from .process_audio import load_and_transform_audio, get_audio_transform +from .process_text import load_and_transform_text +from .process_depth import load_and_transform_depth, get_depth_transform +from .process_thermal import load_and_transform_thermal, get_thermal_transform + +import argparse +from os.path import join as opj +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm + + + +class VAT_dataset(Dataset): + def __init__(self, args): + super().__init__() + self.video_decode_backend = args.video_decode_backend + self.num_frames = args.num_frames + self.text_type = args.text_type + self.total_text = ['raw', 'mplug', 'polish_mplug', 'sound_mplug'] + [f'ofa{i}' for i in range(8)] + self.weight = [0.2, 0.2, 0.2, 0.2] + [0.2 / 8] * 8 + self.title = self.text_type == 'raw' + self.data_root = '/apdcephfs_cq3/share_1311970/A_Youtube' + if args.clip_type != 'al': + with open(args.train_data, 'r') as f: + self.id2title_folder_caps = json.load(f) + self.ids = list(self.id2title_folder_caps.keys())[:args.train_num_samples] + else: + self.id2path_cap, self.ids = get_audio_anno() + + self.clip_type = args.clip_type + + self.num_mel_bins = args.num_mel_bins + self.target_length = args.target_length + self.audio_sample_rate = args.audio_sample_rate + self.audio_mean = args.audio_mean + self.audio_std = args.audio_std + + # self.audio_error_file = open('./audio_error_id.txt', 'w') + + self.tokenizer = get_tokenizer(HF_HUB_PREFIX + args.model, cache_dir=args.cache_dir) + self.video_transform = get_video_transform(args) + self.audio_transform = get_audio_transform(args) + self.depth_transform = get_depth_transform(args) + self.thermal_transform = get_thermal_transform(args) + + def __len__(self): + return len(self.ids) + # return self.id2title_folder_caps.shape[0] + + + def __getitem__(self, idx): + try: + if self.clip_type == 'al': + matched_modality, input_ids, attention_mask = self.get_audio_text(idx) + return matched_modality, input_ids, attention_mask + else: + id = self.ids[idx] + folder = self.id2title_folder_caps[id]['folder'] + text_output, ofa_number = self.get_text(id) + input_ids, attention_mask = text_output['input_ids'], text_output['attention_mask'] + if self.clip_type == 'vl' or self.clip_type == 'vl_new': + matched_modality = self.get_video(id, folder) + # elif self.clip_type == 'al': + # matched_modality = self.get_audio(id, folder) + elif self.clip_type == 'dl': + matched_modality = self.get_depth(id, folder, ofa_number) + elif self.clip_type == 'tl': + matched_modality = self.get_thermal(id, folder, ofa_number) + return matched_modality['pixel_values'], input_ids, attention_mask + except Exception as error_msg: + logging.info(f"Failed at {idx} with \"{error_msg}\"") + return self.__getitem__(random.randint(0, self.__len__()-1)) + + def get_video(self, id, folder): + # video_path = opj(self.data_root, folder, f'{id}.mp4') + resize_folder = 'new_download_resize256_skip15' if folder.startswith('new_') else f'{folder}_resize256_skip15' + video_path = opj(self.data_root, resize_folder, f'{id}.mp4') + video = load_and_transform_video(video_path, self.video_transform, + video_decode_backend=self.video_decode_backend, num_frames=self.num_frames) + return video + + def get_audio_text(self, idx): + + path_cap = self.id2path_cap[self.ids[idx]] + audio_path = path_cap['path'] + audio_data = load_and_transform_audio(audio_path, self.audio_transform) + + caption = path_cap['caption'] + if isinstance(caption, list): + if isinstance(caption[0], str) and len(caption) > 1: + caption = random.choice(caption) + else: + caption = caption[0] + + input_ids, attention_mask = self.tokenizer(caption) + + return audio_data, input_ids.squeeze(), attention_mask.squeeze() + + # def get_audio(self, idx): + ''' + audio_path = opj(self.data_root, folder, f'{id}.mp3') + if os.path.exists(audio_path): + pass + else: + audio_path = audio_path[:-4] + '.m4a' + if os.path.exists(audio_path): + pass + else: + audio_path = audio_path[:-4] + '.wav' + if not os.path.exists(audio_path): + # self.audio_error_file.write(audio_path[:-4] + '\n') + raise FileNotFoundError(f'Not found audio file at \'{audio_path[:-4]}\' with .mp3 .m4a .wav') + # AudioSegment.from_file(audio_path).export(audio_path[:-4] + '.mp3', format='mp3') + # audio_path = opj(self.data_root, folder, f'{id}.mp3') + audio = load_and_transform_audio(audio_path, self.audio_transform) + ''' + + # audio_path = opj(self.data_root, folder+'_ffmpeg_mp3', f'{id}.mp3') + # audio = load_and_transform_audio(audio_path, self.audio_transform) + + + ''' + audiocap_id = self.meta['uniq_id'][idx] + audio_path = f'/apdcephfs_cq3/share_1311970/downstream_datasets/Audio/audiocaps/audio/train/{audiocap_id}.flac' + audio_data = load_and_transform_audio(audio_path, self.audio_transform) + + caption = self.meta['text'][idx] + input_ids, attention_mask = self.tokenizer(caption) + return audio_data, input_ids.squeeze(), attention_mask.squeeze() + ''' + + ''' + path_cap = self.id2path_cap[self.ids[idx]] + audio_path = f"/remote-home/freesound/{path_cap['path']}" + audio_data = load_and_transform_audio(audio_path, self.audio_transform) + + caption = path_cap['caption'] + input_ids, attention_mask = self.tokenizer(caption) + ''' + + # return audio + + + def get_text(self, id): + if self.text_type != 'mix': + text = self.id2title_folder_caps[id][self.text_type] + text_output = load_and_transform_text(text, self.tokenizer, title=self.title) + return text_output, None + else: + text_type = random.choices(self.total_text, self.weight)[0] + ofa_number = None + if text_type.startswith('ofa'): + ofa_number = int(text_type[-1]) + text = self.id2title_folder_caps[id]['ofa'][ofa_number] + else: + text = self.id2title_folder_caps[id][text_type] + text_output = load_and_transform_text(text, self.tokenizer, title=text_type=='raw') + return text_output, ofa_number + + def get_depth(self, id, folder, ofa_number): + depth_folder = opj(self.data_root, folder, f'{id}_depth_f8glpn_folder') + random_id = random.randint(0, 7) if ofa_number is None else ofa_number + # random_id = 3 + depth_path = os.path.join(depth_folder, f'{random_id}.png') + depth = load_and_transform_depth(depth_path, self.depth_transform) + return depth + + def get_thermal(self, id, folder, ofa_number): + thermal_folder = opj(self.data_root, folder, f'{id}_thermal_folder') + random_id = random.randint(0, 7) if ofa_number is None else ofa_number + # random_id = 3 + thermal_path = os.path.join(thermal_folder, f'{random_id}.jpg') + thermal = load_and_transform_thermal(thermal_path, self.thermal_transform) + return thermal + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('Pre-training', add_help=False) + parser.add_argument('--num_frames', default=8, type=float, help='') + parser.add_argument('--workers', default=10, type=int, help='') + args = parser.parse_args() + + args.cache_dir = 'D:\Omni-modal-hf' + args.num_frames = 8 + args.clip_type = 'vl' + args.num_mel_bins = 128 + args.target_length = 1024 + args.audio_sample_rate = 16000 + args.audio_mean = 1 + args.audio_std = 1 + args.rank = 0 + args.batch_size = 16 + + train_dataset = VAT_dataset(args) + load = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers) + + for samples in tqdm((load)): + matched_modality, input_ids, attention_mask = samples + # print(video.shape, text.shape) diff --git a/data/bpe_simple_vocab_16e6.txt.gz b/data/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/data/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/data/build_datasets.py b/data/build_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..0a1eb4a47c093e40a3c612d8af4c97a00141b38e --- /dev/null +++ b/data/build_datasets.py @@ -0,0 +1,247 @@ +import os +import time +from dataclasses import dataclass +from multiprocessing import Value + +import torch +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from data.base_datasets import VAT_dataset +from data.new_loadvat import get_wds_dataset +from open_clip import get_tokenizer +from open_clip.factory import HF_HUB_PREFIX + + +class SharedEpoch: + def __init__(self, epoch: int = 0): + self.shared_epoch = Value('i', epoch) + + def set_value(self, epoch): + self.shared_epoch.value = epoch + + def get_value(self): + return self.shared_epoch.value + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler = None + shared_epoch: SharedEpoch = None + + def set_epoch(self, epoch): + if self.shared_epoch is not None: + self.shared_epoch.set_value(epoch) + if self.sampler is not None and isinstance(self.sampler, DistributedSampler): + self.sampler.set_epoch(epoch) + +def get_VAT_dataset(args): + dataset = VAT_dataset(args) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed else None + shuffle = sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + # prefetch_factor=2, + # persistent_workers=True, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=True, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + +def get_data(args, epoch=0): + data = {} + + if args.do_train: + if args.train_data.endswith(".json"): + data[f"{args.clip_type}_pt"] = get_VAT_dataset(args) + elif args.train_data.endswith(".tar"): + data[f"{args.clip_type}_pt"] = get_wds_dataset(args, is_train=True, epoch=epoch) + else: + raise NameError + + if args.do_eval: + temp_batch_size = args.batch_size + args.batch_size = 8 if args.val_vl_ret_data else 16 + data_root = "/apdcephfs_cq3/share_1311970/downstream_datasets/VideoTextRetrieval/vtRetdata" + if args.val_vl_ret_data: + data["vl_ret"] = [] + for val_vl_ret_data in args.val_vl_ret_data: + if val_vl_ret_data == "msrvtt": + args.train_csv = os.path.join(f'{data_root}/MSRVTT/MSRVTT_train.9k.csv') + args.val_csv = os.path.join(f'{data_root}/MSRVTT/MSRVTT_JSFUSION_test.csv') + args.data_path = os.path.join(f'{data_root}/MSRVTT/MSRVTT_data.json') + args.features_path = os.path.join(f'{data_root}/MSRVTT/MSRVTT_Videos') + elif val_vl_ret_data == "msvd": + args.data_path = os.path.join(f'{data_root}/MSVD') + args.features_path = os.path.join(f'{data_root}/MSVD/MSVD_Videos') + elif val_vl_ret_data == "activity": + args.data_path = os.path.join(f'{data_root}/ActivityNet') + args.features_path = os.path.join(f'{data_root}/ActivityNet/Videos/Activity_Videos') + elif val_vl_ret_data == "didemo": + args.data_path = os.path.join(f'{data_root}/Didemo') + args.features_path = os.path.join(f'{data_root}/Didemo/videos') + else: + raise NameError + + args.batch_size_val = args.batch_size if args.batch_size_val == 0 else args.batch_size_val + args.max_frames = args.num_frames + args.num_thread_reader = args.workers + args.slice_framepos = 2 # "0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly." + + from vl_ret.data_dataloaders import DATALOADER_DICT + + tokenizer = get_tokenizer(HF_HUB_PREFIX + args.model, cache_dir=args.cache_dir) + test_dataloader, test_length = None, 0 + if DATALOADER_DICT[val_vl_ret_data]["test"] is not None: + test_dataloader, test_length = DATALOADER_DICT[val_vl_ret_data]["test"](args, tokenizer) + + if DATALOADER_DICT[val_vl_ret_data]["val"] is not None: + val_dataloader, val_length = DATALOADER_DICT[val_vl_ret_data]["val"](args, tokenizer, subset="val") + else: + val_dataloader, val_length = test_dataloader, test_length + ## report validation results if the ["test"] is None + if test_dataloader is None: + test_dataloader, test_length = val_dataloader, val_length + + data["vl_ret"].append({val_vl_ret_data: test_dataloader}) + + if args.val_v_cls_data: + data["v_cls"] = [] + temp_val_v_cls_data = args.val_v_cls_data + for val_v_cls_data in temp_val_v_cls_data: + from v_cls import get_video_cls_dataloader + args.val_v_cls_data = val_v_cls_data + if args.val_v_cls_data == 'Kinetics-400': + args.video_data_path = "/apdcephfs_cq3/share_1311970/downstream_datasets/VideoCls/new_k400/Kinetics-400/raw/Kinetics-400" + args.nb_classes = 400 + elif args.val_v_cls_data == 'Kinetics-600': + args.video_data_path = "/apdcephfs_cq3/share_1311970/downstream_datasets/VideoCls/new_k600/Kinetics600/raw/Kinetics600" + args.nb_classes = 600 + args.data_root = args.video_data_path + args.data_set = val_v_cls_data + args.dist_eval = True + args.sampling_rate = 8 + args.num_sample = 1 + args.test_num_segment = 5 + args.test_num_crop = 3 + args.num_workers = args.workers + data['v_cls'].append({val_v_cls_data: get_video_cls_dataloader(args)}) + args.val_v_cls_data = temp_val_v_cls_data + + if args.val_a_cls_data: + temp_audio_mean, temp_audio_std = args.audio_mean, args.audio_std + args.audio_mean, args.audio_std = -4.2677393, 4.5689974 + data["a_cls"] = [] + data_root = "/apdcephfs_cq3/share_1311970/downstream_datasets/Audio" + temp_val_a_cls_data = args.val_a_cls_data + for val_a_cls_data in temp_val_a_cls_data: + from a_cls.datasets import get_audio_dataset + args.val_a_cls_data = val_a_cls_data + args.audio_data_path = os.path.join(data_root, f'{val_a_cls_data.lower()}/test') + data['a_cls'].append({val_a_cls_data: get_audio_dataset(args)}) + args.val_a_cls_data = temp_val_a_cls_data + args.audio_mean, args.audio_mean = temp_audio_mean, temp_audio_std + + if args.val_al_ret_data: + temp_audio_mean, temp_audio_std = args.audio_mean, args.audio_std + args.audio_mean, args.audio_std = -4.2677393, 4.5689974 + + data["al_ret"] = [] + data_root = "/apdcephfs_cq3/share_1311970/downstream_datasets/Audio" + temp_val_al_ret_data = args.val_al_ret_data + for val_al_ret_data in temp_val_al_ret_data: + from al_ret.datasets import get_audio_dataset + args.val_al_ret_data = val_al_ret_data + if val_al_ret_data.lower() != 'msrvtt': + args.audio_data_path = os.path.join(data_root, val_al_ret_data.lower()) + data['al_ret'].append({val_al_ret_data: get_audio_dataset(args)}) + elif val_al_ret_data.lower() == 'msrvtt': + args.train_csv = os.path.join(f'/apdcephfs_cq3/share_1311970/downstream_datasets/VideoTextRetrieval/vtRetdata/MSRVTT/MSRVTT_train.9k.csv') + args.val_csv = os.path.join(f'/apdcephfs_cq3/share_1311970/downstream_datasets/VideoTextRetrieval/Audio/MSRVTT/MSRVTT_AUDIO_test.csv') + args.data_path = os.path.join(f'/apdcephfs_cq3/share_1311970/downstream_datasets/VideoTextRetrieval/vtRetdata/MSRVTT/MSRVTT_data.json') + args.features_path = os.path.join(f'/apdcephfs_cq3/share_1311970/downstream_datasets/VideoTextRetrieval/Audio/MSRVTT/videos/all') + + + args.num_thread_reader = args.workers + from al_ret.data_dataloaders import DATALOADER_DICT + args.batch_size_val = args.batch_size if args.batch_size_val == 0 else args.batch_size_val + + tokenizer = get_tokenizer(HF_HUB_PREFIX + args.model, cache_dir=args.cache_dir) + test_dataloader, test_length = None, 0 + if DATALOADER_DICT[val_al_ret_data.lower()]["test"] is not None: + test_dataloader, test_length = DATALOADER_DICT[val_al_ret_data.lower()]["test"](args, tokenizer) + + if DATALOADER_DICT[val_al_ret_data.lower()]["val"] is not None: + val_dataloader, val_length = DATALOADER_DICT[val_al_ret_data.lower()]["val"](args, tokenizer, subset="val") + else: + val_dataloader, val_length = test_dataloader, test_length + ## report validation results if the ["test"] is None + if test_dataloader is None: + test_dataloader, test_length = val_dataloader, val_length + data['al_ret'].append({val_al_ret_data: test_dataloader}) + + args.val_al_ret_data = temp_val_al_ret_data + args.audio_mean, args.audio_mean = temp_audio_mean, temp_audio_std + + if args.val_a_cls_data: + temp_audio_mean, temp_audio_std = args.audio_mean, args.audio_std + args.audio_mean, args.audio_std = -4.2677393, 4.5689974 + data["a_cls"] = [] + data_root = "/apdcephfs_cq3/share_1311970/downstream_datasets/Audio" + temp_val_a_cls_data = args.val_a_cls_data + for val_a_cls_data in temp_val_a_cls_data: + from a_cls.datasets import get_audio_dataset + args.val_a_cls_data = val_a_cls_data + args.audio_data_path = os.path.join(data_root, f'{val_a_cls_data.lower()}/test') + data['a_cls'].append({val_a_cls_data: get_audio_dataset(args)}) + args.val_a_cls_data = temp_val_a_cls_data + args.audio_mean, args.audio_mean = temp_audio_mean, temp_audio_std + + if args.imagenet_val is not None: + from i_cls.datasets import get_imagenet + data['i_cls'] = {} + data['i_cls']["imagenet-val"] = get_imagenet(args, "val") + if args.imagenet_v2 is not None: + from i_cls.datasets import get_imagenet + if data.get('i_cls', None) is None: + data['i_cls'] = {} + data['i_cls']["imagenet-v2"] = get_imagenet(args, "v2") + + if args.val_d_cls_data: + data["d_cls"] = [] + data_root = "/apdcephfs_cq3/share_1311970/downstream_datasets/Depth" + temp_val_d_cls_data = args.val_d_cls_data + for val_d_cls_data in temp_val_d_cls_data: + from d_cls.datasets import get_depth_dataset + args.val_d_cls_data = val_d_cls_data + args.depth_data_path = os.path.join(data_root, f'{val_d_cls_data.lower()}/data/val') + data['d_cls'].append({val_d_cls_data: get_depth_dataset(args)}) + args.val_d_cls_data = temp_val_d_cls_data + + + if args.val_t_cls_data: + data["t_cls"] = [] + data_root = "/apdcephfs_cq3/share_1311970/downstream_datasets/Thermal" + temp_val_t_cls_data = args.val_t_cls_data + for val_t_cls_data in temp_val_t_cls_data: + from t_cls.datasets import get_thermal_dataset + args.val_t_cls_data = val_t_cls_data + args.thermal_data_path = os.path.join(data_root, f'{val_t_cls_data.lower()}/val') + data['t_cls'].append({val_t_cls_data: get_thermal_dataset(args)}) + args.val_t_cls_data = temp_val_t_cls_data + + args.batch_size = temp_batch_size + + return data + + + diff --git a/data/new_loadvat.py b/data/new_loadvat.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc879f6b33b12760efdf685b46315d6825f8b86 --- /dev/null +++ b/data/new_loadvat.py @@ -0,0 +1,498 @@ +import ast +import io +import json +import logging +import math +import os +import random +import sys +import braceexpand +from dataclasses import dataclass +from multiprocessing import Value + +import numpy.lib.format +import numpy as np +import pandas as pd +import torch +import torchvision.datasets as datasets +import webdataset as wds +from PIL import Image +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info +from torch.utils.data.distributed import DistributedSampler +from torchvision.transforms import ToTensor +from tqdm import tqdm +from webdataset.filters import _shuffle +from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample + +from open_clip import get_tokenizer +from open_clip.factory import HF_HUB_PREFIX +from training.params import parse_args +from data.process_text import load_and_transform_text +from data.process_video import get_video_transform +from data.process_audio import get_audio_transform +from data.process_depth import get_depth_transform +from data.process_thermal import get_thermal_transform +import pdb +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + + +class SharedEpoch: + def __init__(self, epoch: int = 0): + self.shared_epoch = Value('i', epoch) + + def set_value(self, epoch): + self.shared_epoch.value = epoch + + def get_value(self): + return self.shared_epoch.value + + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler = None + shared_epoch: SharedEpoch = None + + def set_epoch(self, epoch): + if self.shared_epoch is not None: + self.shared_epoch.set_value(epoch) + if self.sampler is not None and isinstance(self.sampler, DistributedSampler): + self.sampler.set_epoch(epoch) + + +def expand_urls(urls, weights=None): + if weights is None: + expanded_urls = wds.shardlists.expand_urls(urls) + return expanded_urls, None + if isinstance(urls, str): + urllist = urls.split("::") + weights = weights.split('::') + assert len(weights) == len(urllist), \ + f"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match." + weights = [float(weight) for weight in weights] + all_urls, all_weights = [], [] + for url, weight in zip(urllist, weights): + expanded_url = list(braceexpand.braceexpand(url)) + expanded_weights = [weight for _ in expanded_url] + all_urls.extend(expanded_url) + all_weights.extend(expanded_weights) + return all_urls, all_weights + else: + all_urls = list(urls) + return all_urls, weights + + +def get_dataset_size(shards): + shards_list, _ = expand_urls(shards) + dir_path = os.path.dirname(shards_list[0]) + sizes_filename = os.path.join(dir_path, 'sizes.json') + len_filename = os.path.join(dir_path, '__len__') + if os.path.exists(sizes_filename): + sizes = json.load(open(sizes_filename, 'r')) + total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list]) + elif os.path.exists(len_filename): + # FIXME this used to be eval(open(...)) but that seemed rather unsafe + total_size = ast.literal_eval(open(len_filename, 'r').read()) + else: + total_size = None # num samples undefined + # some common dataset sizes (at time of authors last download) + # CC3M (train): 2905954 + # CC12M: 10968539 + # LAION-400M: 407332084 + # LAION-2B (english): 2170337258 + num_shards = len(shards_list) + return total_size, num_shards + + + +def count_samples(dataloader): + os.environ["WDS_EPOCH"] = "0" + n_elements, n_batches = 0, 0 + for images, texts in dataloader: + n_batches += 1 + n_elements += len(images) + assert len(images) == len(texts) + return n_elements, n_batches + + +def filter_no_caption_or_no_image(sample): + has_caption = ('raw.txt' in sample and 'mplug.txt' in sample and 'polish_mplug.txt' in sample and 'ofa3.txt' in sample) + has_image = ('frm7.jpg' in sample and 'tml0.jpg' in sample and 'dep0.npy' in sample) + return has_caption and has_image + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, issue a warning, and continue.""" + logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') + return True + + +def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + # FIXME webdataset version throws if suffix in current_sample, but we have a potential for + # this happening in the current LAION400m dataset if a tar ends with same prefix as the next + # begins, rare, but can happen since prefix aren't unique across tar files in that dataset + if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: + if valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_to_samples_nothrow(src, handler=log_and_continue): + # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw + streams = url_opener(src, handler=handler) + files = tar_file_expander(streams, handler=handler) + samples = group_by_keys_nothrow(files, handler=handler) + return samples + + +def pytorch_worker_seed(increment=0): + """get dataloader worker seed from pytorch""" + worker_info = get_worker_info() + if worker_info is not None: + # favour using the seed already created for pytorch dataloader workers if it exists + seed = worker_info.seed + if increment: + # space out seed increments so they can't overlap across workers in different iterations + seed += increment * max(1, worker_info.num_workers) + return seed + # fallback to wds rank based seed + return wds.utils.pytorch_worker_seed() + + +_SHARD_SHUFFLE_SIZE = 200 +_SHARD_SHUFFLE_INITIAL = 50 +_SAMPLE_SHUFFLE_SIZE = 500 +_SAMPLE_SHUFFLE_INITIAL = 100 + + +class detshuffle2(wds.PipelineStage): + def __init__( + self, + bufsize=1000, + initial=100, + seed=0, + epoch=-1, + ): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + rng = random.Random() + if self.seed < 0: + # If seed is negative, we use the worker's seed, this will be different across all nodes/workers + seed = pytorch_worker_seed(epoch) + else: + # This seed to be deterministic AND the same across all nodes/workers in each epoch + seed = self.seed + epoch + rng.seed(seed) + return _shuffle(src, self.bufsize, self.initial, rng) + + +class ResampledShards2(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__( + self, + urls, + weights=None, + nshards=sys.maxsize, + worker_seed=None, + deterministic=False, + epoch=-1, + ): + """Sample shards from the shard list with replacement. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls, weights = expand_urls(urls, weights) + self.urls = urls + self.weights = weights + if self.weights is not None: + assert len(self.urls) == len(self.weights), \ + f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match." + assert isinstance(self.urls[0], str) + self.nshards = nshards + self.rng = random.Random() + self.worker_seed = worker_seed + self.deterministic = deterministic + self.epoch = epoch + + def __iter__(self): + """Return an iterator over the shards.""" + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + if self.deterministic: + # reset seed w/ epoch if deterministic + if self.worker_seed is None: + # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id + seed = pytorch_worker_seed(epoch) + else: + seed = self.worker_seed() + epoch + self.rng.seed(seed) + for _ in range(self.nshards): + if self.weights is None: + yield dict(url=self.rng.choice(self.urls)) + else: + yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0]) + + +class Decode: + def __init__(self, args=None): + self.num_frames = args.num_frames + self.text_type = args.text_type + self.chatgpt = self.text_type == 'polish_mplug' + self.title = self.text_type == 'raw' + self.clip_type = args.clip_type + self.tokenizer = get_tokenizer(HF_HUB_PREFIX + args.model, cache_dir=args.cache_dir) + self.video_transform = get_video_transform(args) + self.audio_transform = get_audio_transform(args) + self.depth_transform = get_depth_transform(args) + self.thermal_transform = get_thermal_transform(args) + + + def __call__(self, sample): + input_ids, attention_mask = self.get_text(sample[f"{self.text_type}.txt"], chatgpt=self.chatgpt, title=self.title) + if self.clip_type == 'vl': + matched_modality = self.get_video([sample[f"frm{i}.jpg"] for i in range(self.num_frames)]) + elif self.clip_type == 'al': + matched_modality = self.get_audio() + elif self.clip_type == 'dl': + matched_modality = self.get_depth(sample[f"dep0.npy"]) + elif self.clip_type == 'tl': + matched_modality = self.get_thermal(sample[f"tml0.jpg"]) + # matched_modality = self.get_thermal(sample[f"tml{random.randint(0, 7)}.jpg"]) + else: + raise ValueError + return matched_modality, input_ids, attention_mask + + + def get_video(self, frames): + video_data = [] + for frame in frames: + with io.BytesIO(frame) as stream: + img = Image.open(stream) + img.load() + assert min(img.size) == 256 + result = ToTensor()(img) + video_data.append(result) + video_data = torch.stack(video_data, dim=1) + # video_data torch.Size([3, 8, 455, 256]) + # video_outputs torch.Size([3, 8, 224, 224]) + video_outputs = self.video_transform(video_data) + return video_outputs + + + def get_text(self, text, chatgpt=True, title=False): + text = text.decode("utf-8") + if chatgpt: + assert text.startswith('In the video, ') + text = text[14:] + tokens = load_and_transform_text(text, self.tokenizer, title=title) + return tokens['input_ids'], tokens['attention_mask'] + + def get_audio(self): + raise NotImplementedError + + def get_depth(self, depth): + stream = io.BytesIO(depth) + img = numpy.lib.format.read_array(stream) + depth = self.depth_transform(img) + return depth + + def get_thermal(self, thermal): + with io.BytesIO(thermal) as stream: + img = Image.open(stream) + img.load() + thermal = self.thermal_transform(img) + return thermal + +def get_wds_dataset(args, is_train, epoch=0, floor=False): + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + resampled = getattr(args, 'dataset_resampled', False) and is_train + + num_shards = None + if is_train: + if args.train_num_samples is not None: + num_samples = args.train_num_samples + else: + num_samples, num_shards = get_dataset_size(input_shards) + if not num_samples: + raise RuntimeError( + 'Currently, the number of dataset samples must be specified for the training dataset. ' + 'Please specify it via `--train-num-samples` if no dataset length info is present.') + else: + # Eval will just exhaust the iterator if the size is not specified. + num_samples = args.val_num_samples or 0 + + shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + + if resampled: + pipeline = [ResampledShards2( + input_shards, + weights=args.train_data_upsampling_factors, + deterministic=True, + epoch=shared_epoch, + )] + else: + assert args.train_data_upsampling_factors is None, \ + "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)." + pipeline = [wds.SimpleShardList(input_shards)] + + # at this point we have an iterator over all the shards + if is_train: + if not resampled: + pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ]) + pipeline.extend([ + # at this point, we have an iterator over the shards assigned to each worker at each node + tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ]) + else: + pipeline.extend([ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + pipeline.extend([ + wds.select(filter_no_caption_or_no_image), + # wds.decode("pilrgb", handler=log_and_continue), + # wds.rename(image="jpg;png;jpeg;webp", text="txt"), + # wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), + # wds.to_tuple("image", "text"), + wds.map(Decode(args), handler=log_and_continue), + wds.batched(args.batch_size, partial=not is_train) + ]) + + dataset = wds.DataPipeline(*pipeline) + + if is_train: + if not resampled: + num_shards = num_shards or len(expand_urls(input_shards)[0]) + assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=args.workers > 0, + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + + + +def get_data(args, epoch=0): + data = {} + + data["train"] = get_wds_dataset(args, is_train=True, epoch=epoch) + + return data + + +if __name__ == '__main__': + args = parse_args(sys.argv[1:]) + args.workers = 10 + args.batch_size = 16 + args.world_size = 1 + args.num_frames = 8 + args.clip_type = 'vl' + args.model = "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K" + args.train_data = '/apdcephfs_cq3/share_1311970/lb/vat2webdata/check_8frm_title_ofa_polishmplug_1tml_1dep/{00000..03020}.tar' + args.train_num_samples = 10_000 + args.dataset_type = 'webdataset' + + + + data = get_data(args, epoch=0) + + data['train'].set_epoch(0) # set epoch in process safe manner via sampler or shared_epoch + dataloader = data['train'].dataloader + num_batches_per_epoch = dataloader.num_batches // args.accum_freq + print(num_batches_per_epoch) + + + for i, batch in enumerate(tqdm(dataloader)): + images, input_ids, attention_mask = batch + # print(images.shape, input_ids.shape, attention_mask.shape) + # break \ No newline at end of file diff --git a/data/process_audio.py b/data/process_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..f9644146d793d5730392a3590fe2b92694e564aa --- /dev/null +++ b/data/process_audio.py @@ -0,0 +1,118 @@ +import logging + +import numpy as np +import torch +import torchaudio +import torchvision +from torchvision.transforms import transforms +from torch.nn import functional as F + +torchaudio.set_audio_backend("soundfile") + +def torchaudio_loader(path): + return torchaudio.load(path) + +def int16_to_float32_torch(x): + return (x / 32767.0).type(torch.float32) + +def float32_to_int16_torch(x): + x = torch.clamp(x, min=-1., max=1.) + return (x * 32767.).type(torch.int16) + +DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 + +class AudioTransform: + def __init__(self, args): + self.sample_rate = args.audio_sample_rate + self.num_mel_bins = args.num_mel_bins + self.target_length = args.target_length + self.audio_mean = args.audio_mean + self.audio_std = args.audio_std + self.mean = [] + self.std = [] + # mean=-4.2677393 + # std=4.5689974 + # self.norm = transforms.Normalize(mean=self.audio_mean, std=self.audio_std) + + + def __call__(self, audio_data_and_origin_sr): + audio_data, origin_sr = audio_data_and_origin_sr + if self.sample_rate != origin_sr: + # print(audio_data.shape, origin_sr) + audio_data = torchaudio.functional.resample(audio_data, orig_freq=origin_sr, new_freq=self.sample_rate) + waveform_melspec = self.waveform2melspec(audio_data) + return waveform_melspec + + + def waveform2melspec(self, audio_data): + mel = self.get_mel(audio_data) + if mel.shape[0] > self.target_length: + # split to three parts + chunk_frames = self.target_length + total_frames = mel.shape[0] + ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3) + # print('total_frames-chunk_frames:', total_frames-chunk_frames, + # 'len(audio_data):', len(audio_data), + # 'chunk_frames:', chunk_frames, + # 'total_frames:', total_frames) + if len(ranges[1]) == 0: # if the audio is too short, we just use the first chunk + ranges[1] = [0] + if len(ranges[2]) == 0: # if the audio is too short, we just use the first chunk + ranges[2] = [0] + # randomly choose index for each part + idx_front = np.random.choice(ranges[0]) + idx_middle = np.random.choice(ranges[1]) + idx_back = np.random.choice(ranges[2]) + # idx_front = ranges[0][0] # fixed + # idx_middle = ranges[1][0] + # idx_back = ranges[2][0] + # select mel + mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :] + mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :] + mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :] + # print(total_frames, idx_front, idx_front + chunk_frames, idx_middle, idx_middle + chunk_frames, idx_back, idx_back + chunk_frames) + # stack + mel_fusion = torch.stack([mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0) + elif mel.shape[0] < self.target_length: # padding if too short + n_repeat = int(self.target_length / mel.shape[0]) + 1 + # print(self.target_length, mel.shape[0], n_repeat) + mel = mel.repeat(n_repeat, 1)[:self.target_length, :] + mel_fusion = torch.stack([mel, mel, mel], dim=0) + else: # if equal + mel_fusion = torch.stack([mel, mel, mel], dim=0) + mel_fusion = mel_fusion.transpose(1, 2) # [3, target_length, mel_bins] -> [3, mel_bins, target_length] + + # self.mean.append(mel_fusion.mean()) + # self.std.append(mel_fusion.std()) + mel_fusion = (mel_fusion - self.audio_mean) / (self.audio_std * 2) + return mel_fusion + + def get_mel(self, audio_data): + # mel shape: (n_mels, T) + audio_data -= audio_data.mean() + mel = torchaudio.compliance.kaldi.fbank( + audio_data, + htk_compat=True, + sample_frequency=self.sample_rate, + use_energy=False, + window_type="hanning", + num_mel_bins=self.num_mel_bins, + dither=0.0, + frame_length=25, + frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS, + ) + return mel # (T, n_mels) + + + +def get_audio_transform(args): + return AudioTransform(args) + +def load_and_transform_audio( + audio_path, + transform, +): + waveform_and_sr = torchaudio_loader(audio_path) + audio_outputs = transform(waveform_and_sr) + + return audio_outputs \ No newline at end of file diff --git a/data/process_depth.py b/data/process_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..bd33584022802f6e59c94185543ec3347c655f99 --- /dev/null +++ b/data/process_depth.py @@ -0,0 +1,55 @@ +import PIL +import cv2 +import numpy as np +import torch +from PIL import Image +from torch import nn +from torchvision import transforms +from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +def opencv_loader(path): + return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype('float32') + + +class DepthNorm(nn.Module): + def __init__( + self, + max_depth=0, + min_depth=0.01, + ): + super().__init__() + self.max_depth = max_depth + self.min_depth = min_depth + self.scale = 1000.0 # nyuv2 abs.depth + + def forward(self, image): + # image = np.array(image) + depth_img = image / self.scale # (H, W) in meters + depth_img = depth_img.clip(min=self.min_depth) + if self.max_depth != 0: + depth_img = depth_img.clip(max=self.max_depth) + depth_img /= self.max_depth # 0-1 + else: + depth_img /= depth_img.max() + depth_img = torch.from_numpy(depth_img).unsqueeze(0).repeat(3, 1, 1) # assume image + return depth_img.to(torch.get_default_dtype()) + +def get_depth_transform(args): + transform = transforms.Compose( + [ + DepthNorm(max_depth=args.max_depth), + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD), # assume image + # transforms.Normalize((0.5, ), (0.5, )) # 0-1 to norm distribution + # transforms.Normalize((0.0418, ), (0.0295, )) # sun rgb-d imagebind + # transforms.Normalize((0.02, ), (0.00295, )) # nyuv2 + ] + ) + return transform + +def load_and_transform_depth(depth_path, transform): + depth = opencv_loader(depth_path) + depth_outputs = transform(depth) + return {'pixel_values': depth_outputs} diff --git a/data/process_image.py b/data/process_image.py new file mode 100644 index 0000000000000000000000000000000000000000..76ae4497b3fc878774401d2a2584d706f50557bd --- /dev/null +++ b/data/process_image.py @@ -0,0 +1,25 @@ +from PIL import Image + +from open_clip import image_transform, OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +def image_loader(path): + return Image.open(path) + +def get_image_transform(args): + preprocess_val = image_transform( + args.image_size, + is_train=False, + mean=OPENAI_DATASET_MEAN, + std=OPENAI_DATASET_STD, + ) + return preprocess_val + +def load_and_transform_image( + image_path, + transform, +): + image = image_loader(image_path) + image_outputs = transform(image) + + return {'pixel_values': image_outputs} \ No newline at end of file diff --git a/data/process_text.py b/data/process_text.py new file mode 100644 index 0000000000000000000000000000000000000000..cea7d216d9ecb789f5e3eb987a9b5bb1dbcfd1f0 --- /dev/null +++ b/data/process_text.py @@ -0,0 +1,202 @@ +import os + +import torch +import gzip +import html +import io +from functools import lru_cache +from typing import List, Tuple + +import ftfy +import regex as re +from iopath.common.file_io import g_pathmgr +BPE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + +# Modified from github.com/openai/CLIP +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str, context_length=77): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + + with g_pathmgr.open(bpe_path, "rb") as fh: + bpe_bytes = io.BytesIO(fh.read()) + merges: List[str] = gzip.open(bpe_bytes).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges: List[Tuple[str, ...]] = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + vocab.extend(["<|startoftext|>", "<|endoftext|>"]) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>", + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + self.context_length = context_length + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + def __call__(self, texts, context_length=None): + if not context_length: + context_length = self.context_length + + if isinstance(texts, str): + texts = [texts] + + sot_token = self.encoder["<|startoftext|>"] + eot_token = self.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + tokens = tokens[:context_length] + result[i, : len(tokens)] = torch.tensor(tokens) + + if len(result) == 1: + return result[0] + return result + +def clean_youtube(text, is_tags=False): + text = text.lower() + ' ' + text = re.sub( + r'#video|video|#shorts|shorts| shorts|#short| short|#youtubeshorts|youtubeshorts|#youtube| youtube|#shortsyoutube|#ytshorts|ytshorts|#ytshort|#shortvideo|shortvideo|#shortsfeed|#tiktok|tiktok|#tiktokchallenge|#myfirstshorts|#myfirstshort|#viral|viralvideo|viral|#viralshorts|#virlshort|#ytviralshorts', + ' ', text) + text = re.sub(r' s |short|youtube|virlshort|#', ' ', text) + pattern = r'[^a-zA-Z0-9\s\.,;:?!\'\"|]' + if is_tags: + pattern = r'[^a-zA-Z0-9\s]' + text = re.sub(pattern, '', text) + text = whitespace_clean(basic_clean(text)) + return text + +def load_and_transform_text(text, tokenizer, title=True): + if title: + title_hashtags = text.split('#') + title, hashtags = title_hashtags[0], '#' + '#'.join(title_hashtags[1:]) + title = clean_youtube(title) + hashtags = clean_youtube(hashtags, is_tags=True) + text = title + ', ' + hashtags + if text == '' or text.isspace(): + raise ValueError('text is empty') + input_ids, attention_mask = tokenizer(text) + return {'input_ids': input_ids.squeeze(), 'attention_mask': attention_mask.squeeze()} + + + +if __name__ == '__main__': + load_and_transform_text("bpe/bpe_simple_vocab_16e6.txt.gz") \ No newline at end of file diff --git a/data/process_thermal.py b/data/process_thermal.py new file mode 100644 index 0000000000000000000000000000000000000000..8e26870dda6fd7bf6c8576326f13d161073b63a8 --- /dev/null +++ b/data/process_thermal.py @@ -0,0 +1,26 @@ +import PIL +import cv2 +import numpy as np +import torch +from PIL import Image +from torch import nn +from torchvision import transforms +from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + + +def get_thermal_transform(args): + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD) # assume image + ] + ) + return transform + +def load_and_transform_thermal(thermal_path, transform): + thermal = Image.open(thermal_path) + thermal_outputs = transform(thermal) + return {'pixel_values': thermal_outputs} diff --git a/data/process_video.py b/data/process_video.py new file mode 100644 index 0000000000000000000000000000000000000000..1b7171be0318349003b9cd20c3503fbdf24db36d --- /dev/null +++ b/data/process_video.py @@ -0,0 +1,161 @@ + +import io +import logging +import os + +import cv2 +import numpy as np +import torch +import decord +import torchvision.transforms +from PIL import Image +from decord import VideoReader, cpu + +try: + from petrel_client.client import Client + petrel_backend_imported = True +except (ImportError, ModuleNotFoundError): + petrel_backend_imported = False + + +from pytorchvideo.data.encoded_video import EncodedVideo +from torchvision.transforms import Compose, Lambda, ToTensor +from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo +from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample +import sys +sys.path.append('../') +from open_clip import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from os.path import join as opj + + +def get_video_loader(use_petrel_backend: bool = True, + enable_mc: bool = True, + conf_path: str = None): + if petrel_backend_imported and use_petrel_backend: + _client = Client(conf_path=conf_path, enable_mc=enable_mc) + else: + _client = None + + def _loader(video_path): + if _client is not None and 's3:' in video_path: + video_path = io.BytesIO(_client.get(video_path)) + + vr = VideoReader(video_path, num_threads=1, ctx=cpu(0)) + return vr + + return _loader + + +decord.bridge.set_bridge('torch') +# video_loader = get_video_loader() + + +def get_video_transform(args): + if args.video_decode_backend == 'pytorchvideo': + transform = ApplyTransformToKey( + key="video", + transform=Compose( + [ + UniformTemporalSubsample(args.num_frames), + Lambda(lambda x: x / 255.0), + NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD), + ShortSideScale(size=224), + RandomCropVideo(size=224), + RandomHorizontalFlipVideo(p=0.5), + ] + ), + ) + + elif args.video_decode_backend == 'decord': + + transform = Compose( + [ + # UniformTemporalSubsample(num_frames), + Lambda(lambda x: x / 255.0), + NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD), + ShortSideScale(size=224), + RandomCropVideo(size=224), + RandomHorizontalFlipVideo(p=0.5), + ] + ) + + elif args.video_decode_backend == 'opencv': + transform = Compose( + [ + # UniformTemporalSubsample(num_frames), + Lambda(lambda x: x / 255.0), + NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD), + ShortSideScale(size=224), + RandomCropVideo(size=224), + RandomHorizontalFlipVideo(p=0.5), + ] + ) + + elif args.video_decode_backend == 'imgs': + transform = Compose( + [ + # UniformTemporalSubsample(num_frames), + # Lambda(lambda x: x / 255.0), + NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD), + ShortSideScale(size=224), + RandomCropVideo(size=224), + RandomHorizontalFlipVideo(p=0.5), + ] + ) + else: + raise NameError('video_decode_backend should specify in (pytorchvideo, decord, opencv, imgs)') + return transform + +def load_and_transform_video( + video_path, + transform, + video_decode_backend='opencv', + clip_start_sec=0.0, + clip_end_sec=None, + num_frames=8, +): + if video_decode_backend == 'pytorchvideo': + # decord pyav + video = EncodedVideo.from_path(video_path, decoder="decord", decode_audio=False) + duration = video.duration + start_sec = clip_start_sec # secs + end_sec = clip_end_sec if clip_end_sec is not None else duration # secs + video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec) + video_outputs = transform(video_data) + + elif video_decode_backend == 'decord': + decord_vr = VideoReader(video_path, ctx=cpu(0)) + duration = len(decord_vr) + frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int) + video_data = decord_vr.get_batch(frame_id_list) + video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W) + video_outputs = transform(video_data) + + elif video_decode_backend == 'opencv': + cv2_vr = cv2.VideoCapture(video_path) + duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int) + + video_data = [] + for frame_idx in frame_id_list: + cv2_vr.set(1, frame_idx) + _, frame = cv2_vr.read() + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + video_data.append(torch.from_numpy(frame).permute(2, 0, 1)) + cv2_vr.release() + video_data = torch.stack(video_data, dim=1) + video_outputs = transform(video_data) + + elif video_decode_backend == 'imgs': + resize256_folder = video_path.replace('.mp4', '_resize256_folder') + video_data = [ToTensor()(Image.open(opj(resize256_folder, f'{i}.jpg'))) for i in range(8)] + video_data = torch.stack(video_data, dim=1) + # print(video_data.shape, video_data.max(), video_data.min()) + video_outputs = transform(video_data) + + else: + raise NameError('video_decode_backend should specify in (pytorchvideo, decord, opencv, imgs)') + return {'pixel_values': video_outputs} + +if __name__ == '__main__': + load_and_transform_video(r"D:\ONE-PEACE-main\lb_test\zHSOYcZblvY.mp4") \ No newline at end of file diff --git a/gradio_app.py b/gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..a835487290a2385d58ea91ffbbad5d32a1ca125c --- /dev/null +++ b/gradio_app.py @@ -0,0 +1,219 @@ +import sys + +import gradio as gr +import argparse +import numpy as np +import torch +from torch import nn + +from languagebind import LanguageBind, transform_dict, LanguageBindImageTokenizer, to_device + +code_highlight_css = ( +""" +#chatbot .hll { background-color: #ffffcc } +#chatbot .c { color: #408080; font-style: italic } +#chatbot .err { border: 1px solid #FF0000 } +#chatbot .k { color: #008000; font-weight: bold } +#chatbot .o { color: #666666 } +#chatbot .ch { color: #408080; font-style: italic } +#chatbot .cm { color: #408080; font-style: italic } +#chatbot .cp { color: #BC7A00 } +#chatbot .cpf { color: #408080; font-style: italic } +#chatbot .c1 { color: #408080; font-style: italic } +#chatbot .cs { color: #408080; font-style: italic } +#chatbot .gd { color: #A00000 } +#chatbot .ge { font-style: italic } +#chatbot .gr { color: #FF0000 } +#chatbot .gh { color: #000080; font-weight: bold } +#chatbot .gi { color: #00A000 } +#chatbot .go { color: #888888 } +#chatbot .gp { color: #000080; font-weight: bold } +#chatbot .gs { font-weight: bold } +#chatbot .gu { color: #800080; font-weight: bold } +#chatbot .gt { color: #0044DD } +#chatbot .kc { color: #008000; font-weight: bold } +#chatbot .kd { color: #008000; font-weight: bold } +#chatbot .kn { color: #008000; font-weight: bold } +#chatbot .kp { color: #008000 } +#chatbot .kr { color: #008000; font-weight: bold } +#chatbot .kt { color: #B00040 } +#chatbot .m { color: #666666 } +#chatbot .s { color: #BA2121 } +#chatbot .na { color: #7D9029 } +#chatbot .nb { color: #008000 } +#chatbot .nc { color: #0000FF; font-weight: bold } +#chatbot .no { color: #880000 } +#chatbot .nd { color: #AA22FF } +#chatbot .ni { color: #999999; font-weight: bold } +#chatbot .ne { color: #D2413A; font-weight: bold } +#chatbot .nf { color: #0000FF } +#chatbot .nl { color: #A0A000 } +#chatbot .nn { color: #0000FF; font-weight: bold } +#chatbot .nt { color: #008000; font-weight: bold } +#chatbot .nv { color: #19177C } +#chatbot .ow { color: #AA22FF; font-weight: bold } +#chatbot .w { color: #bbbbbb } +#chatbot .mb { color: #666666 } +#chatbot .mf { color: #666666 } +#chatbot .mh { color: #666666 } +#chatbot .mi { color: #666666 } +#chatbot .mo { color: #666666 } +#chatbot .sa { color: #BA2121 } +#chatbot .sb { color: #BA2121 } +#chatbot .sc { color: #BA2121 } +#chatbot .dl { color: #BA2121 } +#chatbot .sd { color: #BA2121; font-style: italic } +#chatbot .s2 { color: #BA2121 } +#chatbot .se { color: #BB6622; font-weight: bold } +#chatbot .sh { color: #BA2121 } +#chatbot .si { color: #BB6688; font-weight: bold } +#chatbot .sx { color: #008000 } +#chatbot .sr { color: #BB6688 } +#chatbot .s1 { color: #BA2121 } +#chatbot .ss { color: #19177C } +#chatbot .bp { color: #008000 } +#chatbot .fm { color: #0000FF } +#chatbot .vc { color: #19177C } +#chatbot .vg { color: #19177C } +#chatbot .vi { color: #19177C } +#chatbot .vm { color: #19177C } +#chatbot .il { color: #666666 } +""") +#.highlight { background: #f8f8f8; } + +title_markdown = (""" +
+ + LanguageBind🚀 + + + + +
+

LanguageBind: Extending Video-Language Pretraining to N-modality by Language-based Semantic Alignment

+ +
If you like our project, please give us a star ✨ on Github for latest update.
+ +
+
+ + + +
+
+""") +css = code_highlight_css + """ +pre { + white-space: pre-wrap; /* Since CSS 2.1 */ + white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ + white-space: -pre-wrap; /* Opera 4-6 */ + white-space: -o-pre-wrap; /* Opera 7 */ + word-wrap: break-word; /* Internet Explorer 5.5+ */ +} +""" + + +def image_to_language(image, language): + inputs = {} + inputs['image'] = to_device(modality_transform['image'](image), device) + inputs['language'] = to_device(modality_transform['language'](language, max_length=77, padding='max_length', + truncation=True, return_tensors='pt'), device) + with torch.no_grad(): + embeddings = model(inputs) + return (embeddings['image'] @ embeddings['language'].T).item() + + +def video_to_language(video, language): + inputs = {} + inputs['video'] = to_device(modality_transform['video'](video), device) + inputs['language'] = to_device(modality_transform['language'](language, max_length=77, padding='max_length', + truncation=True, return_tensors='pt'), device) + with torch.no_grad(): + embeddings = model(inputs) + return (embeddings['video'] @ embeddings['language'].T).item() + + +def audio_to_language(audio, language): + inputs = {} + inputs['audio'] = to_device(modality_transform['audio'](audio), device) + inputs['language'] = to_device(modality_transform['language'](language, max_length=77, padding='max_length', + truncation=True, return_tensors='pt'), device) + with torch.no_grad(): + embeddings = model(inputs) + return (embeddings['audio'] @ embeddings['language'].T).item() + + +def depth_to_language(depth, language): + inputs = {} + inputs['depth'] = to_device(modality_transform['depth'](depth.name), device) + inputs['language'] = to_device(modality_transform['language'](language, max_length=77, padding='max_length', + truncation=True, return_tensors='pt'), device) + with torch.no_grad(): + embeddings = model(inputs) + return (embeddings['depth'] @ embeddings['language'].T).item() + + +def thermal_to_language(thermal, language): + inputs = {} + inputs['thermal'] = to_device(modality_transform['thermal'](thermal), device) + inputs['language'] = to_device(modality_transform['language'](language, max_length=77, padding='max_length', + truncation=True, return_tensors='pt'), device) + with torch.no_grad(): + embeddings = model(inputs) + return (embeddings['thermal'] @ embeddings['language'].T).item() + +if __name__ == '__main__': + device = 'cuda:0' + device = torch.device(device) + clip_type = { + 'video': 'LanguageBind_Video_FT', # also LanguageBind_Video + 'audio': 'LanguageBind_Audio_FT', # also LanguageBind_Audio + 'thermal': 'LanguageBind_Thermal', + 'image': 'LanguageBind_Image', + 'depth': 'LanguageBind_Depth', + } + model = LanguageBind(clip_type=clip_type, use_temp=False) + model = model.to(device) + model.eval() + pretrained_ckpt = f'lb203/LanguageBind_Image' + tokenizer = LanguageBindImageTokenizer.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir/tokenizer_cache_dir') + modality_transform = {c: transform_dict[c](model.modality_config[c]) for c in clip_type} + modality_transform['language'] = tokenizer + + with gr.Blocks(title="LanguageBind🚀", css=css) as demo: + gr.Markdown(title_markdown) + with gr.Row(): + with gr.Column(): + image = gr.Image(type="filepath", height=224, width=224, label='Image Input') + language_i = gr.Textbox(lines=2, label='Text Input') + out_i = gr.Textbox(label='Similarity of Image to Text') + b_i = gr.Button("Calculate similarity of Image to Text") + with gr.Column(): + video = gr.Video(type="filepath", height=224, width=224, label='Video Input') + language_v = gr.Textbox(lines=2, label='Text Input') + out_v = gr.Textbox(label='Similarity of Video to Text') + b_v = gr.Button("Calculate similarity of Video to Text") + with gr.Column(): + audio = gr.Audio(type="filepath", label='Audio Input') + language_a = gr.Textbox(lines=2, label='Text Input') + out_a = gr.Textbox(label='Similarity of Audio to Text') + b_a = gr.Button("Calculate similarity of Audio to Text") + with gr.Row(): + with gr.Column(): + depth = gr.File(height=224, width=224, label='Depth Input, need a .png file, 16 bit, with values ranging from 0-10000 (representing 0-10 metres, but 1000 times)') + language_d = gr.Textbox(lines=2, label='Text Input') + out_d = gr.Textbox(label='Similarity of Depth to Text') + b_d = gr.Button("Calculate similarity of Depth to Text") + with gr.Column(): + thermal = gr.Image(type="filepath", height=224, width=224, label='Thermal Input, you should first convert to RGB') + language_t = gr.Textbox(lines=2, label='Text Input') + out_t = gr.Textbox(label='Similarity of Thermal to Text') + b_t = gr.Button("Calculate similarity of Thermal to Text") + + b_i.click(image_to_language, inputs=[image, language_i], outputs=out_i) + b_a.click(audio_to_language, inputs=[audio, language_a], outputs=out_a) + b_v.click(video_to_language, inputs=[video, language_v], outputs=out_v) + b_d.click(depth_to_language, inputs=[depth, language_d], outputs=out_d) + b_t.click(thermal_to_language, inputs=[thermal, language_t], outputs=out_t) + + demo.launch() diff --git a/i_cls/datasets.py b/i_cls/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..1c30c0e9b908d2af2afdf3dd989f7d2f2db5e21c --- /dev/null +++ b/i_cls/datasets.py @@ -0,0 +1,31 @@ +import torch + +from data.build_datasets import DataInfo +from open_clip import image_transform, OPENAI_DATASET_STD, OPENAI_DATASET_MEAN, get_tokenizer +from torchvision import datasets + + +def get_imagenet(args, split): + assert split in ["val", "v2"] + preprocess_val = image_transform( + args.image_size, + is_train=False, + mean=OPENAI_DATASET_MEAN, + std=OPENAI_DATASET_STD, + ) + if split == "v2": + from imagenetv2_pytorch import ImageNetV2Dataset + dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) + else: + data_path = args.imagenet_val + assert data_path + dataset = datasets.ImageFolder(data_path, transform=preprocess_val) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=None, + ) + + return DataInfo(dataloader=dataloader, sampler=None) diff --git a/i_cls/precision.py b/i_cls/precision.py new file mode 100644 index 0000000000000000000000000000000000000000..a63b92256518d13afd57261df1568e26b1622201 --- /dev/null +++ b/i_cls/precision.py @@ -0,0 +1,12 @@ +import torch +from contextlib import suppress + + +def get_autocast(precision): + if precision == 'amp': + return torch.cuda.amp.autocast + elif precision == 'amp_bfloat16' or precision == 'amp_bf16': + # amp_bfloat16 is more stable than amp float16 for clip training + return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return suppress diff --git a/i_cls/zero_shot.py b/i_cls/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..895acff9afc34e4b463ce4fdf5dacdb1eaff24b3 --- /dev/null +++ b/i_cls/zero_shot.py @@ -0,0 +1,87 @@ +import logging + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ + IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES +from open_clip.factory import HF_HUB_PREFIX +from .precision import get_autocast + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] + + +def run(model, classifier, dataloader, args): + autocast = get_autocast(args.precision) + input_dtype = get_input_dtype(args.precision) + + with torch.no_grad(): + top1, top5, n = 0., 0., 0. + for images, target in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(device=args.device, dtype=input_dtype) + images = images.unsqueeze(2) + target = target.to(args.device) + + with autocast(): + # predict + output = model(image=images) + image_features = output['image_features'] if isinstance(output, dict) else output[0] + logits = 100. * image_features @ classifier + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = (top1 / n) + top5 = (top5 / n) + return top1, top5 + + +def zero_shot_eval(model, data, epoch, args): + if 'imagenet-val' not in data and 'imagenet-v2' not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + if args.distributed and not args.horovod: + model = model.module + + logging.info('Starting zero-shot imagenet.') + + logging.info('Building zero-shot classifier') + autocast = get_autocast(args.precision) + with autocast(): + tokenizer = get_tokenizer(HF_HUB_PREFIX+args.model, cache_dir=args.cache_dir) + # tokenizer = get_tokenizer("ViT-L-14") + classifier = build_zero_shot_classifier( + model, + tokenizer=tokenizer, + classnames=IMAGENET_CLASSNAMES, + templates=OPENAI_IMAGENET_TEMPLATES, + num_classes_per_batch=10, + device=args.device, + use_tqdm=True, + ) + + logging.info('Using classifier') + results = {} + if 'imagenet-val' in data: + top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) + results['imagenet-zeroshot-val-top1'] = top1 + results['imagenet-zeroshot-val-top5'] = top5 + if 'imagenet-v2' in data: + top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) + results['imagenetv2-zeroshot-val-top1'] = top1 + results['imagenetv2-zeroshot-val-top5'] = top5 + + logging.info('Finished zero-shot imagenet.') + + return results diff --git a/i_cls/zeroshot_cls.py b/i_cls/zeroshot_cls.py new file mode 100644 index 0000000000000000000000000000000000000000..874de6616ffec5987451c6b9e29a38536c04c3f7 --- /dev/null +++ b/i_cls/zeroshot_cls.py @@ -0,0 +1,47 @@ + +import json +import logging +import os +from training.distributed import is_master +from .zero_shot import zero_shot_eval + +try: + import wandb +except ImportError: + wandb = None + + + +def evaluate_i_cls(model, data, epoch, args, tb_writer=None): + metrics = {} + if not is_master(args): + return metrics + model.eval() + + zero_shot_metrics = zero_shot_eval(model, data['i_cls'], epoch, args) + metrics.update(zero_shot_metrics) + + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) + ) + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/i_cls/{name}", val, epoch) + args.i_cls_output_dir = os.path.join(args.log_base_path, 'i_cls') + os.makedirs(args.i_cls_output_dir, exist_ok=True) + with open(os.path.join(args.i_cls_output_dir, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, 'Please install wandb.' + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, 'epoch': epoch}) + + return metrics diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..e3047e996a7271f323eede736e18e4af29bb84b6 --- /dev/null +++ b/inference.py @@ -0,0 +1,59 @@ +import torch +from languagebind import LanguageBind, to_device, transform_dict, LanguageBindImageTokenizer + +if __name__ == '__main__': + device = 'cuda:0' + device = torch.device(device) + clip_type = { + 'video': 'LanguageBind_Video_FT', # also LanguageBind_Video + 'audio': 'LanguageBind_Audio_FT', # also LanguageBind_Audio + 'thermal': 'LanguageBind_Thermal', + 'image': 'LanguageBind_Image', + 'depth': 'LanguageBind_Depth', + } + + model = LanguageBind(clip_type=clip_type, cache_dir='./cache_dir') + model = model.to(device) + model.eval() + pretrained_ckpt = f'LanguageBind/LanguageBind_Image' + tokenizer = LanguageBindImageTokenizer.from_pretrained(pretrained_ckpt, cache_dir='./cache_dir/tokenizer_cache_dir') + modality_transform = {c: transform_dict[c](model.modality_config[c]) for c in clip_type.keys()} + + image = ['assets/image/0.jpg', 'assets/image/1.jpg'] + audio = ['assets/audio/0.wav', 'assets/audio/1.wav'] + video = ['assets/video/0.mp4', 'assets/video/1.mp4'] + depth = ['assets/depth/0.png', 'assets/depth/1.png'] + thermal = ['assets/thermal/0.jpg', 'assets/thermal/1.jpg'] + language = ["Training a parakeet to climb up a ladder.", 'A lion climbing a tree to catch a monkey.'] + + inputs = { + 'image': to_device(modality_transform['image'](image), device), + 'video': to_device(modality_transform['video'](video), device), + 'audio': to_device(modality_transform['audio'](audio), device), + 'depth': to_device(modality_transform['depth'](depth), device), + 'thermal': to_device(modality_transform['thermal'](thermal), device), + } + inputs['language'] = to_device(tokenizer(language, max_length=77, padding='max_length', + truncation=True, return_tensors='pt'), device) + + with torch.no_grad(): + embeddings = model(inputs) + + print("Video x Text: \n", + torch.softmax(embeddings['video'] @ embeddings['language'].T, dim=-1).detach().cpu().numpy()) + print("Image x Text: \n", + torch.softmax(embeddings['image'] @ embeddings['language'].T, dim=-1).detach().cpu().numpy()) + print("Depth x Text: \n", + torch.softmax(embeddings['depth'] @ embeddings['language'].T, dim=-1).detach().cpu().numpy()) + print("Audio x Text: \n", + torch.softmax(embeddings['audio'] @ embeddings['language'].T, dim=-1).detach().cpu().numpy()) + print("Thermal x Text: \n", + torch.softmax(embeddings['thermal'] @ embeddings['language'].T, dim=-1).detach().cpu().numpy()) + + print("Video x Audio: \n", + torch.softmax(embeddings['video'] @ embeddings['audio'].T, dim=-1).detach().cpu().numpy()) + print("Image x Depth: \n", + torch.softmax(embeddings['image'] @ embeddings['depth'].T, dim=-1).detach().cpu().numpy()) + print("Image x Thermal: \n", + torch.softmax(embeddings['image'] @ embeddings['thermal'].T, dim=-1).detach().cpu().numpy()) + diff --git a/languagebind/__init__.py b/languagebind/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31f1a3fc69b9c78bd2d4feaebe3dbc1a923e4264 --- /dev/null +++ b/languagebind/__init__.py @@ -0,0 +1,90 @@ +import torch +from torch import nn +from transformers import AutoConfig + +from .image.configuration_image import LanguageBindImageConfig +from .image.modeling_image import LanguageBindImage +from .image.tokenization_image import LanguageBindImageTokenizer +from .image.processing_image import LanguageBindImageProcessor + +from .video.configuration_video import LanguageBindVideoConfig +from .video.modeling_video import LanguageBindVideo +from .video.tokenization_video import LanguageBindVideoTokenizer +from .video.processing_video import LanguageBindVideoProcessor + +from .depth.configuration_depth import LanguageBindDepthConfig +from .depth.modeling_depth import LanguageBindDepth +from .depth.tokenization_depth import LanguageBindDepthTokenizer +from .depth.processing_depth import LanguageBindDepthProcessor + +from .audio.configuration_audio import LanguageBindAudioConfig +from .audio.modeling_audio import LanguageBindAudio +from .audio.tokenization_audio import LanguageBindAudioTokenizer +from .audio.processing_audio import LanguageBindAudioProcessor + +from .thermal.configuration_thermal import LanguageBindThermalConfig +from .thermal.modeling_thermal import LanguageBindThermal +from .thermal.tokenization_thermal import LanguageBindThermalTokenizer +from .thermal.processing_thermal import LanguageBindThermalProcessor + + + +config_dict = { + 'thermal': LanguageBindThermalConfig, + 'image': LanguageBindImageConfig, + 'video': LanguageBindVideoConfig, + 'depth': LanguageBindDepthConfig, + 'audio': LanguageBindAudioConfig +} +model_dict = { + 'thermal': LanguageBindThermal, + 'image': LanguageBindImage, + 'video': LanguageBindVideo, + 'depth': LanguageBindDepth, + 'audio': LanguageBindAudio +} +transform_dict = { + 'video': LanguageBindVideoProcessor, + 'audio': LanguageBindAudioProcessor, + 'depth': LanguageBindDepthProcessor, + 'thermal': LanguageBindThermalProcessor, + 'image': LanguageBindImageProcessor, +} + +class LanguageBind(nn.Module): + def __init__(self, clip_type, use_temp=True, cache_dir='./cache_dir'): + super(LanguageBind, self).__init__() + self.use_temp = use_temp + self.modality_encoder = {} + self.modality_proj = {} + self.modality_scale = {} + self.modality_config = {} + for k, v in clip_type.items(): + pretrained_ckpt = f'LanguageBind/{v}' + model = model_dict[k].from_pretrained(pretrained_ckpt, cache_dir=cache_dir) + self.modality_encoder[k] = model.vision_model + self.modality_proj[k] = model.visual_projection + self.modality_scale[k] = model.logit_scale + self.modality_config[k] = model.config + self.modality_encoder['language'] = model.text_model + self.modality_proj['language'] = model.text_projection + + self.modality_encoder = nn.ModuleDict(self.modality_encoder) + self.modality_proj = nn.ModuleDict(self.modality_proj) + + def forward(self, inputs): + outputs = {} + for key, value in inputs.items(): + value = self.modality_encoder[key](**value)[1] + value = self.modality_proj[key](value) + value = value / value.norm(p=2, dim=-1, keepdim=True) + if self.use_temp: + if key != 'language': + value = value * self.modality_scale[key].exp() + outputs[key] = value + return outputs + +def to_device(x, device): + out_dict = {k: v.to(device) for k, v in x.items()} + return out_dict + diff --git a/languagebind/audio/configuration_audio.py b/languagebind/audio/configuration_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..865a496cff50fbac855413220288cd996965468b --- /dev/null +++ b/languagebind/audio/configuration_audio.py @@ -0,0 +1,430 @@ +import copy +import os +from typing import Union + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + + + + + + +class CLIPTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP + text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the text encoder of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`CLIPModel`]. + hidden_size (`int`, *optional*, defaults to 512): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 77): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import CLIPTextConfig, CLIPTextModel + + >>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPTextConfig() + + >>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "clip_text_model" + + def __init__( + self, + vocab_size=49408, + hidden_size=512, + intermediate_size=2048, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=8, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + # This differs from `CLIPTokenizer`'s default and from openai/clip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.add_time_attn = False ###################################### + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "clip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + + + +class CLIPVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a + CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import CLIPVisionConfig, CLIPVisionModel + + >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPVisionConfig() + + >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "clip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=32, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + + add_time_attn=False, ################################ + num_frames=1, ################################ + force_patch_dropout=0.0, ################################ + lora_r=2, ################################ + lora_alpha=16, ################################ + lora_dropout=0.0, ################################ + num_mel_bins=0.0, ################################ + target_length=0.0, ################################ + video_decode_backend='decord', ######################### + audio_sample_rate=16000, + audio_mean=0.5, + audio_std=0.5, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + self.add_time_attn = add_time_attn ################ + self.num_frames = num_frames ################ + self.force_patch_dropout = force_patch_dropout ################ + self.lora_r = lora_r ################ + self.lora_alpha = lora_alpha ################ + self.lora_dropout = lora_dropout ################ + self.num_mel_bins = num_mel_bins ################ + self.target_length = target_length ################ + self.video_decode_backend = video_decode_backend ################ + + self.audio_sample_rate = audio_sample_rate + self.audio_mean = audio_mean + self.audio_std = audio_std + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "clip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class LanguageBindAudioConfig(PretrainedConfig): + r""" + [`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate + a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating + a configuration with the defaults will yield a similar configuration to that of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`CLIPTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`CLIPVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimentionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import CLIPConfig, CLIPModel + + >>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPConfig() + + >>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig + >>> from transformers import CLIPTextConfig, CLIPVisionConfig + + >>> # Initializing a CLIPText and CLIPVision configuration + >>> config_text = CLIPTextConfig() + >>> config_vision = CLIPVisionConfig() + + >>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "LanguageBindAudio" + is_composition = True + + def __init__( + self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs + ): + # If `_config_dict` exist, we use them for the backward compatibility. + # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot + # of confusion!). + text_config_dict = kwargs.pop("text_config_dict", None) + vision_config_dict = kwargs.pop("vision_config_dict", None) + + super().__init__(**kwargs) + + # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in + # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most + # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`. + if text_config_dict is not None: + if text_config is None: + text_config = {} + + # This is the complete result when using `text_config_dict`. + _text_config_dict = CLIPTextConfig(**text_config_dict).to_dict() + + # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different. + for key, value in _text_config_dict.items(): + if key in text_config and value != text_config[key] and key not in ["transformers_version"]: + # If specified in `text_config_dict` + if key in text_config_dict: + message = ( + f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. " + f'The value `text_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The " + f'value `text_config["{key}"]` will be overriden.' + ) + logger.warning(message) + + # Update all values in `text_config` with the ones in `_text_config_dict`. + text_config.update(_text_config_dict) + + if vision_config_dict is not None: + if vision_config is None: + vision_config = {} + + # This is the complete result when using `vision_config_dict`. + _vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict() + # convert keys to string instead of integer + if "id2label" in _vision_config_dict: + _vision_config_dict["id2label"] = { + str(key): value for key, value in _vision_config_dict["id2label"].items() + } + + # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different. + for key, value in _vision_config_dict.items(): + if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]: + # If specified in `vision_config_dict` + if key in vision_config_dict: + message = ( + f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different " + f'values. The value `vision_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. " + f'The value `vision_config["{key}"]` will be overriden.' + ) + logger.warning(message) + + # Update all values in `vision_config` with the ones in `_vision_config_dict`. + vision_config.update(_vision_config_dict) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.") + + self.text_config = CLIPTextConfig(**text_config) + self.vision_config = CLIPVisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs): + r""" + Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model + configuration. + + Returns: + [`CLIPConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["text_config"] = self.text_config.to_dict() + output["vision_config"] = self.vision_config.to_dict() + output["model_type"] = self.__class__.model_type + return output + + + + + + + + + + diff --git a/languagebind/audio/modeling_audio.py b/languagebind/audio/modeling_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..908ab43e852ccfbdf3a6b4e7546b9f0d11aac78e --- /dev/null +++ b/languagebind/audio/modeling_audio.py @@ -0,0 +1,1030 @@ +import math +from typing import Optional, Tuple, Union + +import torch +from einops import rearrange +from peft import LoraConfig, get_peft_model +from torch import nn +from torch.nn import functional as F +from transformers import PreTrainedModel, add_start_docstrings +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings + +from .configuration_audio import LanguageBindAudioConfig, CLIPVisionConfig, CLIPTextConfig + + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x, B, T): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + if T == 1: + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + else: + rand = torch.randn(B, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1) + patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n') + + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + +class CLIPEncoderLayer(nn.Module): + def __init__(self, config: LanguageBindAudioConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + self.add_time_attn = config.add_time_attn + if self.add_time_attn: + self.t = config.num_frames + self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size)) + nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5) + + self.embed_dim = config.hidden_size + self.temporal_attn = CLIPAttention(config) + self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.temporal_mlp = CLIPMLP(config) + self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + + + if self.add_time_attn: + bt, n, d = hidden_states.shape + t = self.t + + # time embed + if t != 1: + n = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + hidden_states = hidden_states + self.temporal_embedding[:, :t, :] + hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # time attn + residual = hidden_states + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # hidden_states = self.layer_norm1(hidden_states) # share layernorm + hidden_states = self.temporal_layer_norm1(hidden_states) + hidden_states, attn_weights = self.temporal_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + residual = hidden_states + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # hidden_states = self.layer_norm2(hidden_states) # share layernorm + hidden_states = self.temporal_layer_norm2(hidden_states) + hidden_states = self.temporal_mlp(hidden_states) + hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # spatial attn + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + + + + + + + + +class CLIPPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LanguageBindAudioConfig + base_model_prefix = "clip" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, CLIPTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, CLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, CLIPAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, CLIPMLP): + factor = self.config.initializer_factor + in_proj_std = ( + (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + ) + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, LanguageBindAudio): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPVisionModelWithProjection): + nn.init.normal_( + module.visual_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPTextModelWithProjection): + nn.init.normal_( + module.text_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CLIPEncoder): + module.gradient_checkpointing = value + + +CLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__(self, config: LanguageBindAudioConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + causal_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +class CLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = CLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = CLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class CLIPVisionTransformer(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + self.patch_dropout = PatchDropout(config.force_patch_dropout) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = CLIPEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + ###################################### + if len(pixel_values.shape) == 7: + b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape + # print(pixel_values.shape) + B = b_new * pair_new * bs_new + pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new) + + elif len(pixel_values.shape) == 5: + B, _, T, _, _ = pixel_values.shape + # print(pixel_values.shape) + pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w') + else: + # print(pixel_values.shape) + B, _, _, _ = pixel_values.shape + T = 1 + ########################### + hidden_states = self.embeddings(pixel_values) + + hidden_states = self.patch_dropout(hidden_states, B, T) ############################################## + + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + pooled_output = pooled_output.reshape(B, T, -1).mean(1) ################################ + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The vision model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPVisionModel(CLIPPreTrainedModel): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: CLIPVisionConfig): + super().__init__(config) + self.vision_model = CLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPVisionModel + + >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(CLIP_START_DOCSTRING) +class LanguageBindAudio(CLIPPreTrainedModel): + config_class = LanguageBindAudioConfig + + def __init__(self, config: LanguageBindAudioConfig): + super().__init__(config) + + if not isinstance(config.text_config, CLIPTextConfig): + raise ValueError( + "config.text_config is expected to be of type CLIPTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, CLIPVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type CLIPVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + self.add_time_attn = vision_config.add_time_attn + self.lora_r = vision_config.lora_r + self.lora_alpha = vision_config.lora_alpha + self.lora_dropout = vision_config.lora_dropout + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = CLIPTextTransformer(text_config) + self.vision_model = CLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + + # Initialize weights and apply final processing + self.post_init() + self.convert_to_lora() + self.resize_pos(self.vision_model.embeddings, vision_config) + + def convert_to_lora(self): + if self.lora_r == 0: + return + if self.add_time_attn: + target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj", + "temporal_attn.q_proj", "temporal_attn.out_proj", + "temporal_mlp.fc1", "temporal_mlp.fc2"] + else: + target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"] + config = LoraConfig( + r=self.lora_r, # 16 + lora_alpha=self.lora_alpha, # 16 + target_modules=target_modules, # self_attn.out_proj + lora_dropout=self.lora_dropout, # 0.1 + bias="none", + modules_to_save=[], + ) + self.vision_model.encoder.is_gradient_checkpointing = False + self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config) + + def resize_pos(self, m, vision_config): + # convert embedding + if vision_config.num_mel_bins!=0 and vision_config.target_length!=0: + m.image_size = [vision_config.num_mel_bins, vision_config.target_length] + m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size + # pos resize + old_pos_embed_state_dict = m.position_embedding.state_dict() + old_pos_embed = old_pos_embed_state_dict['weight'] + dtype = old_pos_embed.dtype + grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size] + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + # m.to(args.device) + return + + m.num_patches = grid_size[0] * grid_size[1] + m.num_positions = m.num_patches + 1 + m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1))) + new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim) + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2 + + # if is_master(args): + # logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode='bicubic', + antialias=True, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype) + m.position_embedding = new_position_embedding + m.position_embedding.load_state_dict(old_pos_embed_state_dict) + + # m.to(args.device) + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`CLIPTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`CLIPVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindAudioConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CLIPOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = clip_loss(logits_per_text) + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return CLIPOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) \ No newline at end of file diff --git a/languagebind/audio/processing_audio.py b/languagebind/audio/processing_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..0ad3e6c3edef3998e4d5bf8a42b9c72fd484e1e3 --- /dev/null +++ b/languagebind/audio/processing_audio.py @@ -0,0 +1,171 @@ +import cv2 +import numpy as np +import torch +import torchaudio +from torchvision import transforms +from transformers import ProcessorMixin, BatchEncoding +from transformers.image_processing_utils import BatchFeature +from torch.nn import functional as F + + +def make_list_of_images(x): + if not isinstance(x, list): + return [x] + return x + + +torchaudio.set_audio_backend("soundfile") + +def torchaudio_loader(path): + return torchaudio.load(path) + +def int16_to_float32_torch(x): + return (x / 32767.0).type(torch.float32) + +def float32_to_int16_torch(x): + x = torch.clamp(x, min=-1., max=1.) + return (x * 32767.).type(torch.int16) + +DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 + +class AudioTransform: + def __init__(self, args): + self.sample_rate = args.audio_sample_rate + self.num_mel_bins = args.num_mel_bins + self.target_length = args.target_length + self.audio_mean = args.audio_mean + self.audio_std = args.audio_std + self.mean = [] + self.std = [] + # mean=-4.2677393 + # std=4.5689974 + # self.norm = transforms.Normalize(mean=self.audio_mean, std=self.audio_std) + + + def __call__(self, audio_data_and_origin_sr): + audio_data, origin_sr = audio_data_and_origin_sr + if self.sample_rate != origin_sr: + # print(audio_data.shape, origin_sr) + audio_data = torchaudio.functional.resample(audio_data, orig_freq=origin_sr, new_freq=self.sample_rate) + waveform_melspec = self.waveform2melspec(audio_data) + return waveform_melspec + + + def waveform2melspec(self, audio_data): + mel = self.get_mel(audio_data) + if mel.shape[0] > self.target_length: + # split to three parts + chunk_frames = self.target_length + total_frames = mel.shape[0] + ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3) + # print('total_frames-chunk_frames:', total_frames-chunk_frames, + # 'len(audio_data):', len(audio_data), + # 'chunk_frames:', chunk_frames, + # 'total_frames:', total_frames) + if len(ranges[1]) == 0: # if the audio is too short, we just use the first chunk + ranges[1] = [0] + if len(ranges[2]) == 0: # if the audio is too short, we just use the first chunk + ranges[2] = [0] + # randomly choose index for each part + idx_front = np.random.choice(ranges[0]) + idx_middle = np.random.choice(ranges[1]) + idx_back = np.random.choice(ranges[2]) + # idx_front = ranges[0][0] # fixed + # idx_middle = ranges[1][0] + # idx_back = ranges[2][0] + # select mel + mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :] + mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :] + mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :] + # print(total_frames, idx_front, idx_front + chunk_frames, idx_middle, idx_middle + chunk_frames, idx_back, idx_back + chunk_frames) + # stack + mel_fusion = torch.stack([mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0) + elif mel.shape[0] < self.target_length: # padding if too short + n_repeat = int(self.target_length / mel.shape[0]) + 1 + # print(self.target_length, mel.shape[0], n_repeat) + mel = mel.repeat(n_repeat, 1)[:self.target_length, :] + mel_fusion = torch.stack([mel, mel, mel], dim=0) + else: # if equal + mel_fusion = torch.stack([mel, mel, mel], dim=0) + mel_fusion = mel_fusion.transpose(1, 2) # [3, target_length, mel_bins] -> [3, mel_bins, target_length] + + # self.mean.append(mel_fusion.mean()) + # self.std.append(mel_fusion.std()) + mel_fusion = (mel_fusion - self.audio_mean) / (self.audio_std * 2) + return mel_fusion + + def get_mel(self, audio_data): + # mel shape: (n_mels, T) + audio_data -= audio_data.mean() + mel = torchaudio.compliance.kaldi.fbank( + audio_data, + htk_compat=True, + sample_frequency=self.sample_rate, + use_energy=False, + window_type="hanning", + num_mel_bins=self.num_mel_bins, + dither=0.0, + frame_length=25, + frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS, + ) + return mel # (T, n_mels) + +def get_audio_transform(config): + config = config.vision_config + return AudioTransform(config) + + +def load_and_transform_audio( + audio_path, + transform, +): + waveform_and_sr = torchaudio_loader(audio_path) + audio_outputs = transform(waveform_and_sr) + + return audio_outputs + +class LanguageBindAudioProcessor(ProcessorMixin): + attributes = [] + tokenizer_class = ("LanguageBindAudioTokenizer") + + def __init__(self, config, tokenizer=None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.transform = get_audio_transform(config) + self.image_processor = load_and_transform_audio + self.tokenizer = tokenizer + + def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs): + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer(text, max_length=context_length, padding='max_length', + truncation=True, return_tensors=return_tensors, **kwargs) + + if images is not None: + images = make_list_of_images(images) + image_features = [self.image_processor(image, self.transform) for image in images] + image_features = torch.stack(image_features) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features + return encoding + elif text is not None: + return encoding + else: + return {"pixel_values": image_features} + + def batch_decode(self, skip_special_tokens=True, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs) + + def decode(self, skip_special_tokens=True, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs) diff --git a/languagebind/audio/tokenization_audio.py b/languagebind/audio/tokenization_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..6bc40be3f96c20bf2581e23f8249f3cd5566ebe1 --- /dev/null +++ b/languagebind/audio/tokenization_audio.py @@ -0,0 +1,77 @@ +from transformers import CLIPTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "lb203/LanguageBind-Audio": "https://huggingface.co/lb203/LanguageBind-Audio/resolve/main/vocab.json", + }, + "merges_file": { + "lb203/LanguageBind-Audio": "https://huggingface.co/lb203/LanguageBind-Audio/resolve/main/merges.txt", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "lb203/LanguageBind-Audio": 77, +} + + +PRETRAINED_INIT_CONFIGURATION = { + "lb203/LanguageBind-Audio": {}, +} + +class LanguageBindAudioTokenizer(CLIPTokenizer): + """ + Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (`str`, *optional*, defaults to `<|endoftext|>`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*, defaults to `<|startoftext|>`): + The beginning of sequence token. + eos_token (`str`, *optional*, defaults to `<|endoftext|>`): + The end of sequence token. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="<|endoftext|>", + bos_token="<|startoftext|>", + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", # hack to enable padding + **kwargs, + ): + super(LanguageBindAudioTokenizer, self).__init__( + vocab_file, + merges_file, + errors, + unk_token, + bos_token, + eos_token, + pad_token, # hack to enable padding + **kwargs,) \ No newline at end of file diff --git a/languagebind/depth/configuration_depth.py b/languagebind/depth/configuration_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3901b2cf96635384c1e7d1e99845a66cd6c786 --- /dev/null +++ b/languagebind/depth/configuration_depth.py @@ -0,0 +1,425 @@ +import copy +import os +from typing import Union + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + + + + + + +class CLIPTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP + text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the text encoder of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`CLIPModel`]. + hidden_size (`int`, *optional*, defaults to 512): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 77): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import CLIPTextConfig, CLIPTextModel + + >>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPTextConfig() + + >>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "clip_text_model" + + def __init__( + self, + vocab_size=49408, + hidden_size=512, + intermediate_size=2048, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=8, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + # This differs from `CLIPTokenizer`'s default and from openai/clip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.add_time_attn = False ###################################### + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "clip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + + + +class CLIPVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a + CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import CLIPVisionConfig, CLIPVisionModel + + >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPVisionConfig() + + >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "clip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=32, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + + add_time_attn=False, ################################ + num_frames=1, ################################ + force_patch_dropout=0.0, ################################ + lora_r=2, ################################ + lora_alpha=16, ################################ + lora_dropout=0.0, ################################ + num_mel_bins=0.0, ################################ + target_length=0.0, ################################ + max_depth=10, + video_decode_backend='decord', ######################### + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + self.add_time_attn = add_time_attn ################ + self.num_frames = num_frames ################ + self.force_patch_dropout = force_patch_dropout ################ + self.lora_r = lora_r ################ + self.lora_alpha = lora_alpha ################ + self.lora_dropout = lora_dropout ################ + self.num_mel_bins = num_mel_bins ################ + self.target_length = target_length ################ + self.max_depth = max_depth ################ + self.video_decode_backend = video_decode_backend ################ + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "clip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class LanguageBindDepthConfig(PretrainedConfig): + r""" + [`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate + a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating + a configuration with the defaults will yield a similar configuration to that of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`CLIPTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`CLIPVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimentionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import CLIPConfig, CLIPModel + + >>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPConfig() + + >>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig + >>> from transformers import CLIPTextConfig, CLIPVisionConfig + + >>> # Initializing a CLIPText and CLIPVision configuration + >>> config_text = CLIPTextConfig() + >>> config_vision = CLIPVisionConfig() + + >>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "LanguageBindDepth" + is_composition = True + + def __init__( + self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs + ): + # If `_config_dict` exist, we use them for the backward compatibility. + # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot + # of confusion!). + text_config_dict = kwargs.pop("text_config_dict", None) + vision_config_dict = kwargs.pop("vision_config_dict", None) + + super().__init__(**kwargs) + + # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in + # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most + # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`. + if text_config_dict is not None: + if text_config is None: + text_config = {} + + # This is the complete result when using `text_config_dict`. + _text_config_dict = CLIPTextConfig(**text_config_dict).to_dict() + + # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different. + for key, value in _text_config_dict.items(): + if key in text_config and value != text_config[key] and key not in ["transformers_version"]: + # If specified in `text_config_dict` + if key in text_config_dict: + message = ( + f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. " + f'The value `text_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The " + f'value `text_config["{key}"]` will be overriden.' + ) + logger.warning(message) + + # Update all values in `text_config` with the ones in `_text_config_dict`. + text_config.update(_text_config_dict) + + if vision_config_dict is not None: + if vision_config is None: + vision_config = {} + + # This is the complete result when using `vision_config_dict`. + _vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict() + # convert keys to string instead of integer + if "id2label" in _vision_config_dict: + _vision_config_dict["id2label"] = { + str(key): value for key, value in _vision_config_dict["id2label"].items() + } + + # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different. + for key, value in _vision_config_dict.items(): + if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]: + # If specified in `vision_config_dict` + if key in vision_config_dict: + message = ( + f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different " + f'values. The value `vision_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. " + f'The value `vision_config["{key}"]` will be overriden.' + ) + logger.warning(message) + + # Update all values in `vision_config` with the ones in `_vision_config_dict`. + vision_config.update(_vision_config_dict) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.") + + self.text_config = CLIPTextConfig(**text_config) + self.vision_config = CLIPVisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs): + r""" + Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model + configuration. + + Returns: + [`CLIPConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["text_config"] = self.text_config.to_dict() + output["vision_config"] = self.vision_config.to_dict() + output["model_type"] = self.__class__.model_type + return output + + + + + + + + + + diff --git a/languagebind/depth/modeling_depth.py b/languagebind/depth/modeling_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..849eade79b0f4bff345b73bcf6a71115a28d0a09 --- /dev/null +++ b/languagebind/depth/modeling_depth.py @@ -0,0 +1,1030 @@ +import math +from typing import Optional, Tuple, Union + +import torch +from einops import rearrange +from peft import LoraConfig, get_peft_model +from torch import nn +from torch.nn import functional as F +from transformers import PreTrainedModel, add_start_docstrings +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings + +from .configuration_depth import LanguageBindDepthConfig, CLIPVisionConfig, CLIPTextConfig + + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x, B, T): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + if T == 1: + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + else: + rand = torch.randn(B, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1) + patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n') + + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + +class CLIPEncoderLayer(nn.Module): + def __init__(self, config: LanguageBindDepthConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + self.add_time_attn = config.add_time_attn + if self.add_time_attn: + self.t = config.num_frames + self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size)) + nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5) + + self.embed_dim = config.hidden_size + self.temporal_attn = CLIPAttention(config) + self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.temporal_mlp = CLIPMLP(config) + self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + + + if self.add_time_attn: + bt, n, d = hidden_states.shape + t = self.t + + # time embed + if t != 1: + n = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + hidden_states = hidden_states + self.temporal_embedding[:, :t, :] + hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # time attn + residual = hidden_states + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # hidden_states = self.layer_norm1(hidden_states) # share layernorm + hidden_states = self.temporal_layer_norm1(hidden_states) + hidden_states, attn_weights = self.temporal_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + residual = hidden_states + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # hidden_states = self.layer_norm2(hidden_states) # share layernorm + hidden_states = self.temporal_layer_norm2(hidden_states) + hidden_states = self.temporal_mlp(hidden_states) + hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # spatial attn + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + + + + + + + + +class CLIPPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LanguageBindDepthConfig + base_model_prefix = "clip" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, CLIPTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, CLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, CLIPAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, CLIPMLP): + factor = self.config.initializer_factor + in_proj_std = ( + (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + ) + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, LanguageBindDepth): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPVisionModelWithProjection): + nn.init.normal_( + module.visual_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPTextModelWithProjection): + nn.init.normal_( + module.text_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CLIPEncoder): + module.gradient_checkpointing = value + + +CLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__(self, config: LanguageBindDepthConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + causal_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +class CLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = CLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = CLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class CLIPVisionTransformer(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + self.patch_dropout = PatchDropout(config.force_patch_dropout) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = CLIPEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + ###################################### + if len(pixel_values.shape) == 7: + b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape + # print(pixel_values.shape) + B = b_new * pair_new * bs_new + pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new) + + elif len(pixel_values.shape) == 5: + B, _, T, _, _ = pixel_values.shape + # print(pixel_values.shape) + pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w') + else: + # print(pixel_values.shape) + B, _, _, _ = pixel_values.shape + T = 1 + ########################### + hidden_states = self.embeddings(pixel_values) + + hidden_states = self.patch_dropout(hidden_states, B, T) ############################################## + + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + pooled_output = pooled_output.reshape(B, T, -1).mean(1) ################################ + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The vision model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPVisionModel(CLIPPreTrainedModel): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: CLIPVisionConfig): + super().__init__(config) + self.vision_model = CLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPVisionModel + + >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(CLIP_START_DOCSTRING) +class LanguageBindDepth(CLIPPreTrainedModel): + config_class = LanguageBindDepthConfig + + def __init__(self, config: LanguageBindDepthConfig): + super().__init__(config) + + if not isinstance(config.text_config, CLIPTextConfig): + raise ValueError( + "config.text_config is expected to be of type CLIPTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, CLIPVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type CLIPVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + self.add_time_attn = vision_config.add_time_attn + self.lora_r = vision_config.lora_r + self.lora_alpha = vision_config.lora_alpha + self.lora_dropout = vision_config.lora_dropout + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = CLIPTextTransformer(text_config) + self.vision_model = CLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + + # Initialize weights and apply final processing + self.post_init() + self.convert_to_lora() + self.resize_pos(self.vision_model.embeddings, vision_config) + + def convert_to_lora(self): + if self.lora_r == 0: + return + if self.add_time_attn: + target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj", + "temporal_attn.q_proj", "temporal_attn.out_proj", + "temporal_mlp.fc1", "temporal_mlp.fc2"] + else: + target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"] + config = LoraConfig( + r=self.lora_r, # 16 + lora_alpha=self.lora_alpha, # 16 + target_modules=target_modules, # self_attn.out_proj + lora_dropout=self.lora_dropout, # 0.1 + bias="none", + modules_to_save=[], + ) + self.vision_model.encoder.is_gradient_checkpointing = False + self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config) + + def resize_pos(self, m, vision_config): + # convert embedding + if vision_config.num_mel_bins!=0 and vision_config.target_length!=0: + m.image_size = [vision_config.num_mel_bins, vision_config.target_length] + m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size + # pos resize + old_pos_embed_state_dict = m.position_embedding.state_dict() + old_pos_embed = old_pos_embed_state_dict['weight'] + dtype = old_pos_embed.dtype + grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size] + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + # m.to(args.device) + return + + m.num_patches = grid_size[0] * grid_size[1] + m.num_positions = m.num_patches + 1 + m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1))) + new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim) + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2 + + # if is_master(args): + # logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode='bicubic', + antialias=True, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype) + m.position_embedding = new_position_embedding + m.position_embedding.load_state_dict(old_pos_embed_state_dict) + + # m.to(args.device) + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`CLIPTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`CLIPVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindDepthConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CLIPOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = clip_loss(logits_per_text) + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return CLIPOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) \ No newline at end of file diff --git a/languagebind/depth/processing_depth.py b/languagebind/depth/processing_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..1019e0cb45c8be4bc7424c4d8f9d091dac5dab0b --- /dev/null +++ b/languagebind/depth/processing_depth.py @@ -0,0 +1,108 @@ +import cv2 +import torch +from PIL import Image +from torch import nn +from torchvision import transforms +from transformers import ProcessorMixin, BatchEncoding +from transformers.image_processing_utils import BatchFeature + +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) + +def make_list_of_images(x): + if not isinstance(x, list): + return [x] + return x + +def opencv_loader(path): + return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype('float32') + + +class DepthNorm(nn.Module): + def __init__( + self, + max_depth=0, + min_depth=0.01, + ): + super().__init__() + self.max_depth = max_depth + self.min_depth = min_depth + self.scale = 1000.0 # nyuv2 abs.depth + + def forward(self, image): + # image = np.array(image) + depth_img = image / self.scale # (H, W) in meters + depth_img = depth_img.clip(min=self.min_depth) + if self.max_depth != 0: + depth_img = depth_img.clip(max=self.max_depth) + depth_img /= self.max_depth # 0-1 + else: + depth_img /= depth_img.max() + depth_img = torch.from_numpy(depth_img).unsqueeze(0).repeat(3, 1, 1) # assume image + return depth_img.to(torch.get_default_dtype()) + +def get_depth_transform(config): + config = config.vision_config + transform = transforms.Compose( + [ + DepthNorm(max_depth=config.max_depth), + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD), # assume image + # transforms.Normalize((0.5, ), (0.5, )) # 0-1 to norm distribution + # transforms.Normalize((0.0418, ), (0.0295, )) # sun rgb-d imagebind + # transforms.Normalize((0.02, ), (0.00295, )) # nyuv2 + ] + ) + return transform + +def load_and_transform_depth(depth_path, transform): + depth = opencv_loader(depth_path) + depth_outputs = transform(depth) + return depth_outputs + +class LanguageBindDepthProcessor(ProcessorMixin): + attributes = [] + tokenizer_class = ("LanguageBindDepthTokenizer") + + def __init__(self, config, tokenizer=None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.transform = get_depth_transform(config) + self.image_processor = load_and_transform_depth + self.tokenizer = tokenizer + + def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs): + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer(text, max_length=context_length, padding='max_length', + truncation=True, return_tensors=return_tensors, **kwargs) + + if images is not None: + images = make_list_of_images(images) + image_features = [self.image_processor(image, self.transform) for image in images] + image_features = torch.stack(image_features) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features + return encoding + elif text is not None: + return encoding + else: + return {"pixel_values": image_features} + + def batch_decode(self, skip_special_tokens=True, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs) + + def decode(self, skip_special_tokens=True, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs) diff --git a/languagebind/depth/tokenization_depth.py b/languagebind/depth/tokenization_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..eda9905131c2240cddf982b2937fe96cb33b4053 --- /dev/null +++ b/languagebind/depth/tokenization_depth.py @@ -0,0 +1,77 @@ +from transformers import CLIPTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "lb203/LanguageBind-Depth": "https://huggingface.co/lb203/LanguageBind-Depth/resolve/main/vocab.json", + }, + "merges_file": { + "lb203/LanguageBind-Depth": "https://huggingface.co/lb203/LanguageBind-Depth/resolve/main/merges.txt", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "lb203/LanguageBind-Depth": 77, +} + + +PRETRAINED_INIT_CONFIGURATION = { + "lb203/LanguageBind-Thermal": {}, +} + +class LanguageBindDepthTokenizer(CLIPTokenizer): + """ + Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (`str`, *optional*, defaults to `<|endoftext|>`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*, defaults to `<|startoftext|>`): + The beginning of sequence token. + eos_token (`str`, *optional*, defaults to `<|endoftext|>`): + The end of sequence token. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="<|endoftext|>", + bos_token="<|startoftext|>", + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", # hack to enable padding + **kwargs, + ): + super(LanguageBindDepthTokenizer, self).__init__( + vocab_file, + merges_file, + errors, + unk_token, + bos_token, + eos_token, + pad_token, # hack to enable padding + **kwargs,) \ No newline at end of file diff --git a/languagebind/image/configuration_image.py b/languagebind/image/configuration_image.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c7b0f7aad10f791c89b2f89aa4161defb990ae --- /dev/null +++ b/languagebind/image/configuration_image.py @@ -0,0 +1,423 @@ +import copy +import os +from typing import Union + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + + + + + + +class CLIPTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP + text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the text encoder of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`CLIPModel`]. + hidden_size (`int`, *optional*, defaults to 512): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 77): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import CLIPTextConfig, CLIPTextModel + + >>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPTextConfig() + + >>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "clip_text_model" + + def __init__( + self, + vocab_size=49408, + hidden_size=512, + intermediate_size=2048, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=8, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + # This differs from `CLIPTokenizer`'s default and from openai/clip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.add_time_attn = False ###################################### + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "clip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + + + +class CLIPVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a + CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import CLIPVisionConfig, CLIPVisionModel + + >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPVisionConfig() + + >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "clip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=32, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + + add_time_attn=False, ################################ + num_frames=1, ################################ + force_patch_dropout=0.0, ################################ + lora_r=2, ################################ + lora_alpha=16, ################################ + lora_dropout=0.0, ################################ + num_mel_bins=0.0, ################################ + target_length=0.0, ################################ + video_decode_backend='decord', ######################### + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + self.add_time_attn = add_time_attn ################ + self.num_frames = num_frames ################ + self.force_patch_dropout = force_patch_dropout ################ + self.lora_r = lora_r ################ + self.lora_alpha = lora_alpha ################ + self.lora_dropout = lora_dropout ################ + self.num_mel_bins = num_mel_bins ################ + self.target_length = target_length ################ + self.video_decode_backend = video_decode_backend ################ + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "clip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class LanguageBindImageConfig(PretrainedConfig): + r""" + [`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate + a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating + a configuration with the defaults will yield a similar configuration to that of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`CLIPTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`CLIPVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimentionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import CLIPConfig, CLIPModel + + >>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPConfig() + + >>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig + >>> from transformers import CLIPTextConfig, CLIPVisionConfig + + >>> # Initializing a CLIPText and CLIPVision configuration + >>> config_text = CLIPTextConfig() + >>> config_vision = CLIPVisionConfig() + + >>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "LanguageBindImage" + is_composition = True + + def __init__( + self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs + ): + # If `_config_dict` exist, we use them for the backward compatibility. + # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot + # of confusion!). + text_config_dict = kwargs.pop("text_config_dict", None) + vision_config_dict = kwargs.pop("vision_config_dict", None) + + super().__init__(**kwargs) + + # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in + # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most + # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`. + if text_config_dict is not None: + if text_config is None: + text_config = {} + + # This is the complete result when using `text_config_dict`. + _text_config_dict = CLIPTextConfig(**text_config_dict).to_dict() + + # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different. + for key, value in _text_config_dict.items(): + if key in text_config and value != text_config[key] and key not in ["transformers_version"]: + # If specified in `text_config_dict` + if key in text_config_dict: + message = ( + f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. " + f'The value `text_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The " + f'value `text_config["{key}"]` will be overriden.' + ) + logger.warning(message) + + # Update all values in `text_config` with the ones in `_text_config_dict`. + text_config.update(_text_config_dict) + + if vision_config_dict is not None: + if vision_config is None: + vision_config = {} + + # This is the complete result when using `vision_config_dict`. + _vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict() + # convert keys to string instead of integer + if "id2label" in _vision_config_dict: + _vision_config_dict["id2label"] = { + str(key): value for key, value in _vision_config_dict["id2label"].items() + } + + # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different. + for key, value in _vision_config_dict.items(): + if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]: + # If specified in `vision_config_dict` + if key in vision_config_dict: + message = ( + f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different " + f'values. The value `vision_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. " + f'The value `vision_config["{key}"]` will be overriden.' + ) + logger.warning(message) + + # Update all values in `vision_config` with the ones in `_vision_config_dict`. + vision_config.update(_vision_config_dict) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.") + + self.text_config = CLIPTextConfig(**text_config) + self.vision_config = CLIPVisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs): + r""" + Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model + configuration. + + Returns: + [`CLIPConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["text_config"] = self.text_config.to_dict() + output["vision_config"] = self.vision_config.to_dict() + output["model_type"] = self.__class__.model_type + return output + + + + + + + + + + diff --git a/languagebind/image/modeling_image.py b/languagebind/image/modeling_image.py new file mode 100644 index 0000000000000000000000000000000000000000..7228f5daed51a2f2b0c94d9fd68076eff1a39ae1 --- /dev/null +++ b/languagebind/image/modeling_image.py @@ -0,0 +1,1030 @@ +import math +from typing import Optional, Tuple, Union + +import torch +from einops import rearrange +from peft import LoraConfig, get_peft_model +from torch import nn +from torch.nn import functional as F +from transformers import PreTrainedModel, add_start_docstrings +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings + +from .configuration_image import LanguageBindImageConfig, CLIPVisionConfig, CLIPTextConfig + + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x, B, T): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + if T == 1: + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + else: + rand = torch.randn(B, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1) + patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n') + + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + +class CLIPEncoderLayer(nn.Module): + def __init__(self, config: LanguageBindImageConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + self.add_time_attn = config.add_time_attn + if self.add_time_attn: + self.t = config.num_frames + self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size)) + nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5) + + self.embed_dim = config.hidden_size + self.temporal_attn = CLIPAttention(config) + self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.temporal_mlp = CLIPMLP(config) + self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + + + if self.add_time_attn: + bt, n, d = hidden_states.shape + t = self.t + + # time embed + if t != 1: + n = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + hidden_states = hidden_states + self.temporal_embedding[:, :t, :] + hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # time attn + residual = hidden_states + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # hidden_states = self.layer_norm1(hidden_states) # share layernorm + hidden_states = self.temporal_layer_norm1(hidden_states) + hidden_states, attn_weights = self.temporal_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + residual = hidden_states + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # hidden_states = self.layer_norm2(hidden_states) # share layernorm + hidden_states = self.temporal_layer_norm2(hidden_states) + hidden_states = self.temporal_mlp(hidden_states) + hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # spatial attn + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + + + + + + + + +class CLIPPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LanguageBindImageConfig + base_model_prefix = "clip" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, CLIPTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, CLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, CLIPAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, CLIPMLP): + factor = self.config.initializer_factor + in_proj_std = ( + (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + ) + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, LanguageBindImage): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPVisionModelWithProjection): + nn.init.normal_( + module.visual_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPTextModelWithProjection): + nn.init.normal_( + module.text_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CLIPEncoder): + module.gradient_checkpointing = value + + +CLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__(self, config: LanguageBindImageConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + causal_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +class CLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = CLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = CLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class CLIPVisionTransformer(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + self.patch_dropout = PatchDropout(config.force_patch_dropout) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = CLIPEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + ###################################### + if len(pixel_values.shape) == 7: + b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape + # print(pixel_values.shape) + B = b_new * pair_new * bs_new + pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new) + + elif len(pixel_values.shape) == 5: + B, _, T, _, _ = pixel_values.shape + # print(pixel_values.shape) + pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w') + else: + # print(pixel_values.shape) + B, _, _, _ = pixel_values.shape + T = 1 + ########################### + hidden_states = self.embeddings(pixel_values) + + hidden_states = self.patch_dropout(hidden_states, B, T) ############################################## + + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + pooled_output = pooled_output.reshape(B, T, -1).mean(1) ################################ + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The vision model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPVisionModel(CLIPPreTrainedModel): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: CLIPVisionConfig): + super().__init__(config) + self.vision_model = CLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPVisionModel + + >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(CLIP_START_DOCSTRING) +class LanguageBindImage(CLIPPreTrainedModel): + config_class = LanguageBindImageConfig + + def __init__(self, config: LanguageBindImageConfig): + super().__init__(config) + + if not isinstance(config.text_config, CLIPTextConfig): + raise ValueError( + "config.text_config is expected to be of type CLIPTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, CLIPVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type CLIPVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + self.add_time_attn = vision_config.add_time_attn + self.lora_r = vision_config.lora_r + self.lora_alpha = vision_config.lora_alpha + self.lora_dropout = vision_config.lora_dropout + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = CLIPTextTransformer(text_config) + self.vision_model = CLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + + # Initialize weights and apply final processing + self.post_init() + self.convert_to_lora() + self.resize_pos(self.vision_model.embeddings, vision_config) + + def convert_to_lora(self): + if self.lora_r == 0: + return + if self.add_time_attn: + target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj", + "temporal_attn.q_proj", "temporal_attn.out_proj", + "temporal_mlp.fc1", "temporal_mlp.fc2"] + else: + target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"] + config = LoraConfig( + r=self.lora_r, # 16 + lora_alpha=self.lora_alpha, # 16 + target_modules=target_modules, # self_attn.out_proj + lora_dropout=self.lora_dropout, # 0.1 + bias="none", + modules_to_save=[], + ) + self.vision_model.encoder.is_gradient_checkpointing = False + self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config) + + def resize_pos(self, m, vision_config): + # convert embedding + if vision_config.num_mel_bins!=0 and vision_config.target_length!=0: + m.image_size = [vision_config.num_mel_bins, vision_config.target_length] + m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size + # pos resize + old_pos_embed_state_dict = m.position_embedding.state_dict() + old_pos_embed = old_pos_embed_state_dict['weight'] + dtype = old_pos_embed.dtype + grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size] + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + # m.to(args.device) + return + + m.num_patches = grid_size[0] * grid_size[1] + m.num_positions = m.num_patches + 1 + m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1))) + new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim) + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2 + + # if is_master(args): + # logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode='bicubic', + antialias=True, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype) + m.position_embedding = new_position_embedding + m.position_embedding.load_state_dict(old_pos_embed_state_dict) + + # m.to(args.device) + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`CLIPTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`CLIPVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindImageConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CLIPOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = clip_loss(logits_per_text) + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return CLIPOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) \ No newline at end of file diff --git a/languagebind/image/processing_image.py b/languagebind/image/processing_image.py new file mode 100644 index 0000000000000000000000000000000000000000..b61db106a39d2c44f1a8c1709afe08b4231c4c64 --- /dev/null +++ b/languagebind/image/processing_image.py @@ -0,0 +1,77 @@ +import torch +from PIL import Image +from torchvision import transforms +from transformers import ProcessorMixin, BatchEncoding +from transformers.image_processing_utils import BatchFeature + +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) + +def make_list_of_images(x): + if not isinstance(x, list): + return [x] + return x + +def get_image_transform(config): + config = config.vision_config + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD) # assume image + ] + ) + return transform + + +def load_and_transform_image(image_path, transform): + image = Image.open(image_path) + image_outputs = transform(image) + return image_outputs + +class LanguageBindImageProcessor(ProcessorMixin): + attributes = [] + tokenizer_class = ("LanguageBindImageTokenizer") + + def __init__(self, config, tokenizer=None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.transform = get_image_transform(config) + self.image_processor = load_and_transform_image + self.tokenizer = tokenizer + + def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs): + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer(text, max_length=context_length, padding='max_length', + truncation=True, return_tensors=return_tensors, **kwargs) + + if images is not None: + images = make_list_of_images(images) + image_features = [self.image_processor(image, self.transform) for image in images] + image_features = torch.stack(image_features) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features + return encoding + elif text is not None: + return encoding + else: + return {"pixel_values": image_features} + + def batch_decode(self, skip_special_tokens=True, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs) + + def decode(self, skip_special_tokens=True, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs) diff --git a/languagebind/image/tokenization_image.py b/languagebind/image/tokenization_image.py new file mode 100644 index 0000000000000000000000000000000000000000..593423d089100b3d61957f658cca04b541336f65 --- /dev/null +++ b/languagebind/image/tokenization_image.py @@ -0,0 +1,77 @@ +from transformers import CLIPTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "lb203/LanguageBind-Image": "https://huggingface.co/lb203/LanguageBind-Image/resolve/main/vocab.json", + }, + "merges_file": { + "lb203/LanguageBind-Image": "https://huggingface.co/lb203/LanguageBind-Image/resolve/main/merges.txt", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "lb203/LanguageBind-Image": 77, +} + + +PRETRAINED_INIT_CONFIGURATION = { + "lb203/LanguageBind-Image": {}, +} + +class LanguageBindImageTokenizer(CLIPTokenizer): + """ + Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (`str`, *optional*, defaults to `<|endoftext|>`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*, defaults to `<|startoftext|>`): + The beginning of sequence token. + eos_token (`str`, *optional*, defaults to `<|endoftext|>`): + The end of sequence token. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="<|endoftext|>", + bos_token="<|startoftext|>", + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", # hack to enable padding + **kwargs, + ): + super(LanguageBindImageTokenizer, self).__init__( + vocab_file, + merges_file, + errors, + unk_token, + bos_token, + eos_token, + pad_token, # hack to enable padding + **kwargs,) \ No newline at end of file diff --git a/languagebind/thermal/configuration_thermal.py b/languagebind/thermal/configuration_thermal.py new file mode 100644 index 0000000000000000000000000000000000000000..fd6cedd5d44c248b32e89f51d5c28595bffcbefc --- /dev/null +++ b/languagebind/thermal/configuration_thermal.py @@ -0,0 +1,423 @@ +import copy +import os +from typing import Union + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + + + + + + +class CLIPTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP + text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the text encoder of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`CLIPModel`]. + hidden_size (`int`, *optional*, defaults to 512): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 77): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import CLIPTextConfig, CLIPTextModel + + >>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPTextConfig() + + >>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "clip_text_model" + + def __init__( + self, + vocab_size=49408, + hidden_size=512, + intermediate_size=2048, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=8, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + # This differs from `CLIPTokenizer`'s default and from openai/clip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.add_time_attn = False ###################################### + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "clip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + + + +class CLIPVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a + CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import CLIPVisionConfig, CLIPVisionModel + + >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPVisionConfig() + + >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "clip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=32, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + + add_time_attn=False, ################################ + num_frames=1, ################################ + force_patch_dropout=0.0, ################################ + lora_r=2, ################################ + lora_alpha=16, ################################ + lora_dropout=0.0, ################################ + num_mel_bins=0.0, ################################ + target_length=0.0, ################################ + video_decode_backend='decord', ######################### + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + self.add_time_attn = add_time_attn ################ + self.num_frames = num_frames ################ + self.force_patch_dropout = force_patch_dropout ################ + self.lora_r = lora_r ################ + self.lora_alpha = lora_alpha ################ + self.lora_dropout = lora_dropout ################ + self.num_mel_bins = num_mel_bins ################ + self.target_length = target_length ################ + self.video_decode_backend = video_decode_backend ################ + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "clip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class LanguageBindThermalConfig(PretrainedConfig): + r""" + [`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate + a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating + a configuration with the defaults will yield a similar configuration to that of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`CLIPTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`CLIPVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimentionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import CLIPConfig, CLIPModel + + >>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPConfig() + + >>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig + >>> from transformers import CLIPTextConfig, CLIPVisionConfig + + >>> # Initializing a CLIPText and CLIPVision configuration + >>> config_text = CLIPTextConfig() + >>> config_vision = CLIPVisionConfig() + + >>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "LanguageBindThermal" + is_composition = True + + def __init__( + self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs + ): + # If `_config_dict` exist, we use them for the backward compatibility. + # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot + # of confusion!). + text_config_dict = kwargs.pop("text_config_dict", None) + vision_config_dict = kwargs.pop("vision_config_dict", None) + + super().__init__(**kwargs) + + # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in + # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most + # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`. + if text_config_dict is not None: + if text_config is None: + text_config = {} + + # This is the complete result when using `text_config_dict`. + _text_config_dict = CLIPTextConfig(**text_config_dict).to_dict() + + # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different. + for key, value in _text_config_dict.items(): + if key in text_config and value != text_config[key] and key not in ["transformers_version"]: + # If specified in `text_config_dict` + if key in text_config_dict: + message = ( + f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. " + f'The value `text_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The " + f'value `text_config["{key}"]` will be overriden.' + ) + logger.warning(message) + + # Update all values in `text_config` with the ones in `_text_config_dict`. + text_config.update(_text_config_dict) + + if vision_config_dict is not None: + if vision_config is None: + vision_config = {} + + # This is the complete result when using `vision_config_dict`. + _vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict() + # convert keys to string instead of integer + if "id2label" in _vision_config_dict: + _vision_config_dict["id2label"] = { + str(key): value for key, value in _vision_config_dict["id2label"].items() + } + + # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different. + for key, value in _vision_config_dict.items(): + if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]: + # If specified in `vision_config_dict` + if key in vision_config_dict: + message = ( + f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different " + f'values. The value `vision_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. " + f'The value `vision_config["{key}"]` will be overriden.' + ) + logger.warning(message) + + # Update all values in `vision_config` with the ones in `_vision_config_dict`. + vision_config.update(_vision_config_dict) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.") + + self.text_config = CLIPTextConfig(**text_config) + self.vision_config = CLIPVisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs): + r""" + Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model + configuration. + + Returns: + [`CLIPConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["text_config"] = self.text_config.to_dict() + output["vision_config"] = self.vision_config.to_dict() + output["model_type"] = self.__class__.model_type + return output + + + + + + + + + + diff --git a/languagebind/thermal/modeling_thermal.py b/languagebind/thermal/modeling_thermal.py new file mode 100644 index 0000000000000000000000000000000000000000..f0323b3351a4eed0165a8b7a1e8cc610ea0669ca --- /dev/null +++ b/languagebind/thermal/modeling_thermal.py @@ -0,0 +1,1030 @@ +import math +from typing import Optional, Tuple, Union + +import torch +from einops import rearrange +from peft import LoraConfig, get_peft_model +from torch import nn +from torch.nn import functional as F +from transformers import PreTrainedModel, add_start_docstrings +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings + +from .configuration_thermal import LanguageBindThermalConfig, CLIPVisionConfig, CLIPTextConfig + + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x, B, T): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + if T == 1: + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + else: + rand = torch.randn(B, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1) + patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n') + + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + +class CLIPEncoderLayer(nn.Module): + def __init__(self, config: LanguageBindThermalConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + self.add_time_attn = config.add_time_attn + if self.add_time_attn: + self.t = config.num_frames + self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size)) + nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5) + + self.embed_dim = config.hidden_size + self.temporal_attn = CLIPAttention(config) + self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.temporal_mlp = CLIPMLP(config) + self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + + + if self.add_time_attn: + bt, n, d = hidden_states.shape + t = self.t + + # time embed + if t != 1: + n = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + hidden_states = hidden_states + self.temporal_embedding[:, :t, :] + hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # time attn + residual = hidden_states + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # hidden_states = self.layer_norm1(hidden_states) # share layernorm + hidden_states = self.temporal_layer_norm1(hidden_states) + hidden_states, attn_weights = self.temporal_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + residual = hidden_states + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # hidden_states = self.layer_norm2(hidden_states) # share layernorm + hidden_states = self.temporal_layer_norm2(hidden_states) + hidden_states = self.temporal_mlp(hidden_states) + hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # spatial attn + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + + + + + + + + +class CLIPPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LanguageBindThermalConfig + base_model_prefix = "clip" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, CLIPTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, CLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, CLIPAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, CLIPMLP): + factor = self.config.initializer_factor + in_proj_std = ( + (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + ) + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, LanguageBindThermal): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPVisionModelWithProjection): + nn.init.normal_( + module.visual_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPTextModelWithProjection): + nn.init.normal_( + module.text_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CLIPEncoder): + module.gradient_checkpointing = value + + +CLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__(self, config: LanguageBindThermalConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + causal_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +class CLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = CLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = CLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class CLIPVisionTransformer(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + self.patch_dropout = PatchDropout(config.force_patch_dropout) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = CLIPEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + ###################################### + if len(pixel_values.shape) == 7: + b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape + # print(pixel_values.shape) + B = b_new * pair_new * bs_new + pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new) + + elif len(pixel_values.shape) == 5: + B, _, T, _, _ = pixel_values.shape + # print(pixel_values.shape) + pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w') + else: + # print(pixel_values.shape) + B, _, _, _ = pixel_values.shape + T = 1 + ########################### + hidden_states = self.embeddings(pixel_values) + + hidden_states = self.patch_dropout(hidden_states, B, T) ############################################## + + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + pooled_output = pooled_output.reshape(B, T, -1).mean(1) ################################ + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The vision model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPVisionModel(CLIPPreTrainedModel): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: CLIPVisionConfig): + super().__init__(config) + self.vision_model = CLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPVisionModel + + >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(CLIP_START_DOCSTRING) +class LanguageBindThermal(CLIPPreTrainedModel): + config_class = LanguageBindThermalConfig + + def __init__(self, config: LanguageBindThermalConfig): + super().__init__(config) + + if not isinstance(config.text_config, CLIPTextConfig): + raise ValueError( + "config.text_config is expected to be of type CLIPTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, CLIPVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type CLIPVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + self.add_time_attn = vision_config.add_time_attn + self.lora_r = vision_config.lora_r + self.lora_alpha = vision_config.lora_alpha + self.lora_dropout = vision_config.lora_dropout + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = CLIPTextTransformer(text_config) + self.vision_model = CLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + + # Initialize weights and apply final processing + self.post_init() + self.convert_to_lora() + self.resize_pos(self.vision_model.embeddings, vision_config) + + def convert_to_lora(self): + if self.lora_r == 0: + return + if self.add_time_attn: + target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj", + "temporal_attn.q_proj", "temporal_attn.out_proj", + "temporal_mlp.fc1", "temporal_mlp.fc2"] + else: + target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"] + config = LoraConfig( + r=self.lora_r, # 16 + lora_alpha=self.lora_alpha, # 16 + target_modules=target_modules, # self_attn.out_proj + lora_dropout=self.lora_dropout, # 0.1 + bias="none", + modules_to_save=[], + ) + self.vision_model.encoder.is_gradient_checkpointing = False + self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config) + + def resize_pos(self, m, vision_config): + # convert embedding + if vision_config.num_mel_bins!=0 and vision_config.target_length!=0: + m.image_size = [vision_config.num_mel_bins, vision_config.target_length] + m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size + # pos resize + old_pos_embed_state_dict = m.position_embedding.state_dict() + old_pos_embed = old_pos_embed_state_dict['weight'] + dtype = old_pos_embed.dtype + grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size] + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + # m.to(args.device) + return + + m.num_patches = grid_size[0] * grid_size[1] + m.num_positions = m.num_patches + 1 + m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1))) + new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim) + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2 + + # if is_master(args): + # logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode='bicubic', + antialias=True, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype) + m.position_embedding = new_position_embedding + m.position_embedding.load_state_dict(old_pos_embed_state_dict) + + # m.to(args.device) + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`CLIPTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`CLIPVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindThermalConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CLIPOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = clip_loss(logits_per_text) + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return CLIPOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) \ No newline at end of file diff --git a/languagebind/thermal/processing_thermal.py b/languagebind/thermal/processing_thermal.py new file mode 100644 index 0000000000000000000000000000000000000000..36ed1f09d3bf23514baf4859e462d28bc49dfd53 --- /dev/null +++ b/languagebind/thermal/processing_thermal.py @@ -0,0 +1,77 @@ +import torch +from PIL import Image +from torchvision import transforms +from transformers import ProcessorMixin, BatchEncoding +from transformers.image_processing_utils import BatchFeature + +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) + +def make_list_of_images(x): + if not isinstance(x, list): + return [x] + return x + +def get_thermal_transform(config): + config = config.vision_config + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD) # assume image + ] + ) + return transform + + +def load_and_transform_thermal(thermal_path, transform): + thermal = Image.open(thermal_path) + thermal_outputs = transform(thermal) + return thermal_outputs + +class LanguageBindThermalProcessor(ProcessorMixin): + attributes = [] + tokenizer_class = ("LanguageBindThermalTokenizer") + + def __init__(self, config, tokenizer=None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.transform = get_thermal_transform(config) + self.image_processor = load_and_transform_thermal + self.tokenizer = tokenizer + + def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs): + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer(text, max_length=context_length, padding='max_length', + truncation=True, return_tensors=return_tensors, **kwargs) + + if images is not None: + images = make_list_of_images(images) + image_features = [self.image_processor(image, self.transform) for image in images] + image_features = torch.stack(image_features) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features + return encoding + elif text is not None: + return encoding + else: + return {"pixel_values": image_features} + + def batch_decode(self, skip_special_tokens=True, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs) + + def decode(self, skip_special_tokens=True, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs) diff --git a/languagebind/thermal/tokenization_thermal.py b/languagebind/thermal/tokenization_thermal.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ebb5607bc8f2a24341a7b11f22663e760012dd --- /dev/null +++ b/languagebind/thermal/tokenization_thermal.py @@ -0,0 +1,77 @@ +from transformers import CLIPTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "lb203/LanguageBind-Thermal": "https://huggingface.co/lb203/LanguageBind-Thermal/resolve/main/vocab.json", + }, + "merges_file": { + "lb203/LanguageBind-Thermal": "https://huggingface.co/lb203/LanguageBind-Thermal/resolve/main/merges.txt", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "lb203/LanguageBind-Thermal": 77, +} + + +PRETRAINED_INIT_CONFIGURATION = { + "lb203/LanguageBind-Thermal": {}, +} + +class LanguageBindThermalTokenizer(CLIPTokenizer): + """ + Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (`str`, *optional*, defaults to `<|endoftext|>`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*, defaults to `<|startoftext|>`): + The beginning of sequence token. + eos_token (`str`, *optional*, defaults to `<|endoftext|>`): + The end of sequence token. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="<|endoftext|>", + bos_token="<|startoftext|>", + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", # hack to enable padding + **kwargs, + ): + super(LanguageBindThermalTokenizer, self).__init__( + vocab_file, + merges_file, + errors, + unk_token, + bos_token, + eos_token, + pad_token, # hack to enable padding + **kwargs,) \ No newline at end of file diff --git a/languagebind/video/configuration_video.py b/languagebind/video/configuration_video.py new file mode 100644 index 0000000000000000000000000000000000000000..4b108ec51799ae0d77432ffa85690e1a1858e60c --- /dev/null +++ b/languagebind/video/configuration_video.py @@ -0,0 +1,423 @@ +import copy +import os +from typing import Union + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + + + + + + +class CLIPTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CLIPTextModel`]. It is used to instantiate a CLIP + text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the text encoder of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`CLIPModel`]. + hidden_size (`int`, *optional*, defaults to 512): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 77): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import CLIPTextConfig, CLIPTextModel + + >>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPTextConfig() + + >>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "clip_text_model" + + def __init__( + self, + vocab_size=49408, + hidden_size=512, + intermediate_size=2048, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=8, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + # This differs from `CLIPTokenizer`'s default and from openai/clip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.add_time_attn = False ###################################### + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "clip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + + + +class CLIPVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CLIPVisionModel`]. It is used to instantiate a + CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import CLIPVisionConfig, CLIPVisionModel + + >>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPVisionConfig() + + >>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "clip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + projection_dim=512, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=32, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + + add_time_attn=False, ################################ + num_frames=1, ################################ + force_patch_dropout=0.0, ################################ + lora_r=2, ################################ + lora_alpha=16, ################################ + lora_dropout=0.0, ################################ + num_mel_bins=0.0, ################################ + target_length=0.0, ################################ + video_decode_backend='decord', ######################### + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + self.add_time_attn = add_time_attn ################ + self.num_frames = num_frames ################ + self.force_patch_dropout = force_patch_dropout ################ + self.lora_r = lora_r ################ + self.lora_alpha = lora_alpha ################ + self.lora_dropout = lora_dropout ################ + self.num_mel_bins = num_mel_bins ################ + self.target_length = target_length ################ + self.video_decode_backend = video_decode_backend ################ + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "clip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class LanguageBindVideoConfig(PretrainedConfig): + r""" + [`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate + a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating + a configuration with the defaults will yield a similar configuration to that of the CLIP + [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`CLIPTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`CLIPVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimentionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import CLIPConfig, CLIPModel + + >>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration + >>> configuration = CLIPConfig() + + >>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration + >>> model = CLIPModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig + >>> from transformers import CLIPTextConfig, CLIPVisionConfig + + >>> # Initializing a CLIPText and CLIPVision configuration + >>> config_text = CLIPTextConfig() + >>> config_vision = CLIPVisionConfig() + + >>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "LanguageBindVideo" + is_composition = True + + def __init__( + self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs + ): + # If `_config_dict` exist, we use them for the backward compatibility. + # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot + # of confusion!). + text_config_dict = kwargs.pop("text_config_dict", None) + vision_config_dict = kwargs.pop("vision_config_dict", None) + + super().__init__(**kwargs) + + # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in + # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most + # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`. + if text_config_dict is not None: + if text_config is None: + text_config = {} + + # This is the complete result when using `text_config_dict`. + _text_config_dict = CLIPTextConfig(**text_config_dict).to_dict() + + # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different. + for key, value in _text_config_dict.items(): + if key in text_config and value != text_config[key] and key not in ["transformers_version"]: + # If specified in `text_config_dict` + if key in text_config_dict: + message = ( + f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. " + f'The value `text_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The " + f'value `text_config["{key}"]` will be overriden.' + ) + logger.warning(message) + + # Update all values in `text_config` with the ones in `_text_config_dict`. + text_config.update(_text_config_dict) + + if vision_config_dict is not None: + if vision_config is None: + vision_config = {} + + # This is the complete result when using `vision_config_dict`. + _vision_config_dict = CLIPVisionConfig(**vision_config_dict).to_dict() + # convert keys to string instead of integer + if "id2label" in _vision_config_dict: + _vision_config_dict["id2label"] = { + str(key): value for key, value in _vision_config_dict["id2label"].items() + } + + # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different. + for key, value in _vision_config_dict.items(): + if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]: + # If specified in `vision_config_dict` + if key in vision_config_dict: + message = ( + f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different " + f'values. The value `vision_config_dict["{key}"]` will be used instead.' + ) + # If inferred from default argument values (just to be super careful) + else: + message = ( + f"`vision_config_dict` is provided which will be used to initialize `CLIPVisionConfig`. " + f'The value `vision_config["{key}"]` will be overriden.' + ) + logger.warning(message) + + # Update all values in `vision_config` with the ones in `_vision_config_dict`. + vision_config.update(_vision_config_dict) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.") + + self.text_config = CLIPTextConfig(**text_config) + self.vision_config = CLIPVisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs): + r""" + Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model + configuration. + + Returns: + [`CLIPConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["text_config"] = self.text_config.to_dict() + output["vision_config"] = self.vision_config.to_dict() + output["model_type"] = self.__class__.model_type + return output + + + + + + + + + + diff --git a/languagebind/video/modeling_video.py b/languagebind/video/modeling_video.py new file mode 100644 index 0000000000000000000000000000000000000000..d4a1c33be995ca94a0d533fee50999b55bdd6c0b --- /dev/null +++ b/languagebind/video/modeling_video.py @@ -0,0 +1,1142 @@ +import math +from typing import Optional, Tuple, Union + +import torch +from einops import rearrange +from peft import LoraConfig, get_peft_model +from torch import nn +from torch.nn import functional as F +from transformers import PreTrainedModel, add_start_docstrings +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \ + CLIPVisionModelWithProjection, CLIPTextModelWithProjection, _expand_mask, CLIPOutput, clip_loss +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings + +from .configuration_video import LanguageBindVideoConfig, CLIPVisionConfig, CLIPTextConfig + + + +class CLIPVisionEmbeddings(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + # (b t) c h w + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) # b hw c + return embeddings + +class CLIPVisionEmbeddings3D(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.num_frames = config.num_frames + self.tube_size = getattr(config, 'tube_size', 1) + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) + + self.expand3d() + + def expand3d(self): + + state_dict = self.patch_embedding.state_dict() + state_dict_expand = state_dict['weight'].unsqueeze(2) + device, dtype = state_dict_expand.device, state_dict_expand.dtype + # print(device, dtype) + + zero = torch.zeros_like(state_dict_expand).to(device=device, dtype=dtype) + state_dict_expand3d = torch.cat([state_dict_expand] + (self.tube_size-1)*[zero], dim=2) + + # state_dict_expand3d = torch.cat([state_dict_expand / self.tube_size] * self.tube_size, dim=2) + + patch_embedding = nn.Conv3d( + in_channels=self.patch_embedding.in_channels, + out_channels=self.embed_dim, + kernel_size=(self.tube_size, self.patch_size, self.patch_size), + stride=(self.tube_size, self.patch_size, self.patch_size), + bias=False, + ).to(device=device, dtype=dtype) + patch_embedding.load_state_dict({'weight': state_dict_expand3d}) + self.patch_embedding = patch_embedding + + + class_embedding = nn.Parameter(self.class_embedding.data.repeat(self.num_frames // self.tube_size, 1)).to(device=device, dtype=dtype) + self.class_embedding = class_embedding + + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + # (b t) c h w + batch_size = pixel_values.shape[0] // self.num_frames + pixel_values = rearrange(pixel_values, '(b t) c h w -> b c t h w', b=batch_size, t=self.num_frames) + # print('pixel_values', pixel_values.shape) + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, t, grid, grid] + # print('patch_embeds', patch_embeds.shape) + # SET_GLOBAL_VALUE('NUM_FRAMES', patch_embeds.shape[2]) + patch_embeds = rearrange(patch_embeds, 'b c t h w -> b t (h w) c') + + class_embeds = self.class_embedding.unsqueeze(1).unsqueeze(0).repeat(batch_size, 1, 1, 1) # b t 1 c + # print('class_embeds', class_embeds.device, class_embeds.dtype) + # print('patch_embeds', patch_embeds.device, patch_embeds.dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=2) # b t hw+1 c + embeddings = embeddings + self.position_embedding(self.position_ids) + embeddings = rearrange(embeddings, 'b t hw_1 c -> (b t) hw_1 c') + return embeddings + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x, B, T): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + if T == 1: + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + else: + rand = torch.randn(B, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1) + patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n') + + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + +class CLIPEncoderLayer(nn.Module): + def __init__(self, config: LanguageBindVideoConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + self.add_time_attn = config.add_time_attn + if self.add_time_attn: + self.t = config.num_frames + self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size)) + nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5) + + self.embed_dim = config.hidden_size + self.temporal_attn = CLIPAttention(config) + self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + # self.temporal_mlp = CLIPMLP(config) + # self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + + + if self.add_time_attn: + bt, n, d = hidden_states.shape + t = self.t + + # time embed + if t != 1: + n = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + hidden_states = hidden_states + self.temporal_embedding[:, :t, :] + hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # time attn + residual = hidden_states + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # hidden_states = self.layer_norm1(hidden_states) # share layernorm + hidden_states = self.temporal_layer_norm1(hidden_states) + hidden_states, attn_weights = self.temporal_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # residual = hidden_states + # hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # # hidden_states = self.layer_norm2(hidden_states) # share layernorm + # hidden_states = self.temporal_layer_norm2(hidden_states) + # hidden_states = self.temporal_mlp(hidden_states) + # hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # spatial attn + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + + + + + + + + +class CLIPPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LanguageBindVideoConfig + base_model_prefix = "clip" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, CLIPTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, CLIPVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, CLIPAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, CLIPMLP): + factor = self.config.initializer_factor + in_proj_std = ( + (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + ) + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, LanguageBindVideo): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPVisionModelWithProjection): + nn.init.normal_( + module.visual_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + elif isinstance(module, CLIPTextModelWithProjection): + nn.init.normal_( + module.text_projection.weight, + std=self.config.hidden_size**-0.5 * self.config.initializer_factor, + ) + + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CLIPEncoder): + module.gradient_checkpointing = value + + +CLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`CLIPConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +CLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__(self, config: LanguageBindVideoConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + causal_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +class CLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = CLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = CLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class CLIPVisionTransformer(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + vl_new = getattr(config, 'clip_type', 'vl') == 'vl_new' + add_time_attn = config.add_time_attn + # self.embeddings = CLIPVisionEmbeddings(config) + if add_time_attn: + if vl_new: + self.embeddings = CLIPVisionEmbeddings3D(config) + else: + self.embeddings = CLIPVisionEmbeddings(config) + + self.patch_dropout = PatchDropout(config.force_patch_dropout) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = CLIPEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + ###################################### + if len(pixel_values.shape) == 7: + b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape + # print(pixel_values.shape) + B = b_new * pair_new * bs_new + pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new) + + elif len(pixel_values.shape) == 5: + B, _, T, _, _ = pixel_values.shape + # print(pixel_values.shape) + pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w') + else: + # print(pixel_values.shape) + B, _, _, _ = pixel_values.shape + T = 1 + ########################### + hidden_states = self.embeddings(pixel_values) + + hidden_states = self.patch_dropout(hidden_states, B, T) ############################################## + + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + pooled_output = pooled_output.reshape(B, T, -1).mean(1) ################################ + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The vision model from CLIP without any head or projection on top.""", + CLIP_START_DOCSTRING, +) +class CLIPVisionModel(CLIPPreTrainedModel): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: CLIPVisionConfig): + super().__init__(config) + self.vision_model = CLIPVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPVisionModel + + >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(CLIP_START_DOCSTRING) +class LanguageBindVideo(CLIPPreTrainedModel): + config_class = LanguageBindVideoConfig + + def __init__(self, config: LanguageBindVideoConfig): + super().__init__(config) + + if not isinstance(config.text_config, CLIPTextConfig): + raise ValueError( + "config.text_config is expected to be of type CLIPTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, CLIPVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type CLIPVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + self.add_time_attn = vision_config.add_time_attn + self.lora_r = vision_config.lora_r + self.lora_alpha = vision_config.lora_alpha + self.lora_dropout = vision_config.lora_dropout + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = CLIPTextTransformer(text_config) + self.vision_model = CLIPVisionTransformer(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) + + # Initialize weights and apply final processing + self.post_init() + self.convert_to_lora() + # self.resize_pos(self.vision_model.embeddings, vision_config) + + def convert_to_lora(self): + if self.lora_r == 0: + return + if self.add_time_attn: + target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj", + "temporal_attn.q_proj", "temporal_attn.out_proj", + "temporal_mlp.fc1", "temporal_mlp.fc2"] + else: + target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"] + config = LoraConfig( + r=self.lora_r, # 16 + lora_alpha=self.lora_alpha, # 16 + target_modules=target_modules, # self_attn.out_proj + lora_dropout=self.lora_dropout, # 0.1 + bias="none", + modules_to_save=[], + ) + self.vision_model.encoder.is_gradient_checkpointing = False + self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config) + + def resize_pos(self, m, vision_config): + # convert embedding + if vision_config.num_mel_bins!=0 and vision_config.target_length!=0: + m.image_size = [vision_config.num_mel_bins, vision_config.target_length] + m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size + # pos resize + old_pos_embed_state_dict = m.position_embedding.state_dict() + old_pos_embed = old_pos_embed_state_dict['weight'] + dtype = old_pos_embed.dtype + grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size] + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + # m.to(args.device) + return + + m.num_patches = grid_size[0] * grid_size[1] + m.num_positions = m.num_patches + 1 + m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1))) + new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim) + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2 + + # if is_master(args): + # logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode='bicubic', + antialias=True, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype) + m.position_embedding = new_position_embedding + m.position_embedding.load_state_dict(old_pos_embed_state_dict) + + # m.to(args.device) + + @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`CLIPTextModel`]. + + Examples: + + ```python + >>> from transformers import AutoTokenizer, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`CLIPVisionModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindVideoConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CLIPOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPModel + + >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = clip_loss(logits_per_text) + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return CLIPOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) \ No newline at end of file diff --git a/languagebind/video/processing_video.py b/languagebind/video/processing_video.py new file mode 100644 index 0000000000000000000000000000000000000000..fdea0fd4fffa8eb6d4fff6b600ee02e7abe45c06 --- /dev/null +++ b/languagebind/video/processing_video.py @@ -0,0 +1,161 @@ +import cv2 +import decord +import numpy as np +import torch +from PIL import Image +from decord import VideoReader, cpu +from torchvision import transforms +from transformers import ProcessorMixin, BatchEncoding +from transformers.image_processing_utils import BatchFeature +from pytorchvideo.data.encoded_video import EncodedVideo +from torchvision.transforms import Compose, Lambda, ToTensor +from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo, CenterCropVideo +from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample + +decord.bridge.set_bridge('torch') + +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) + +def make_list_of_images(x): + if not isinstance(x, list): + return [x] + return x + +def get_video_transform(config): + config = config.vision_config + if config.video_decode_backend == 'pytorchvideo': + transform = ApplyTransformToKey( + key="video", + transform=Compose( + [ + UniformTemporalSubsample(config.num_frames), + Lambda(lambda x: x / 255.0), + NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD), + ShortSideScale(size=224), + CenterCropVideo(224), + RandomHorizontalFlipVideo(p=0.5), + ] + ), + ) + + elif config.video_decode_backend == 'decord': + + transform = Compose( + [ + # UniformTemporalSubsample(num_frames), + Lambda(lambda x: x / 255.0), + NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD), + ShortSideScale(size=224), + CenterCropVideo(224), + RandomHorizontalFlipVideo(p=0.5), + ] + ) + + elif config.video_decode_backend == 'opencv': + transform = Compose( + [ + # UniformTemporalSubsample(num_frames), + Lambda(lambda x: x / 255.0), + NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD), + ShortSideScale(size=224), + CenterCropVideo(224), + RandomHorizontalFlipVideo(p=0.5), + ] + ) + else: + raise NameError('video_decode_backend should specify in (pytorchvideo, decord, opencv)') + return transform + + +def load_and_transform_video( + video_path, + transform, + video_decode_backend='opencv', + clip_start_sec=0.0, + clip_end_sec=None, + num_frames=8, +): + if video_decode_backend == 'pytorchvideo': + # decord pyav + video = EncodedVideo.from_path(video_path, decoder="decord", decode_audio=False) + duration = video.duration + start_sec = clip_start_sec # secs + end_sec = clip_end_sec if clip_end_sec is not None else duration # secs + video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec) + video_outputs = transform(video_data) + + elif video_decode_backend == 'decord': + decord.bridge.set_bridge('torch') + decord_vr = VideoReader(video_path, ctx=cpu(0)) + duration = len(decord_vr) + frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int) + video_data = decord_vr.get_batch(frame_id_list) + video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W) + video_outputs = transform(video_data) + + elif video_decode_backend == 'opencv': + cv2_vr = cv2.VideoCapture(video_path) + duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int) + + video_data = [] + for frame_idx in frame_id_list: + cv2_vr.set(1, frame_idx) + _, frame = cv2_vr.read() + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + video_data.append(torch.from_numpy(frame).permute(2, 0, 1)) + cv2_vr.release() + video_data = torch.stack(video_data, dim=1) + video_outputs = transform(video_data) + else: + raise NameError('video_decode_backend should specify in (pytorchvideo, decord, opencv)') + return video_outputs + +class LanguageBindVideoProcessor(ProcessorMixin): + attributes = [] + tokenizer_class = ("LanguageBindVideoTokenizer") + + def __init__(self, config, tokenizer=None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.transform = get_video_transform(config) + self.image_processor = load_and_transform_video + self.tokenizer = tokenizer + + def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs): + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer(text, max_length=context_length, padding='max_length', + truncation=True, return_tensors=return_tensors, **kwargs) + + if images is not None: + images = make_list_of_images(images) + image_features = [self.image_processor(image, self.transform, + video_decode_backend=self.config.vision_config.video_decode_backend, + num_frames=self.config.vision_config.num_frames) for image in images] + image_features = torch.stack(image_features) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features + return encoding + elif text is not None: + return encoding + else: + return {"pixel_values": image_features} + + def batch_decode(self, skip_special_tokens=True, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs) + + def decode(self, skip_special_tokens=True, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs) diff --git a/languagebind/video/tokenization_video.py b/languagebind/video/tokenization_video.py new file mode 100644 index 0000000000000000000000000000000000000000..2864429c098770fd37fd61e8a7b82d1fee5b12dd --- /dev/null +++ b/languagebind/video/tokenization_video.py @@ -0,0 +1,77 @@ +from transformers import CLIPTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "lb203/LanguageBind-Video": "https://huggingface.co/lb203/LanguageBind-Video/resolve/main/vocab.json", + }, + "merges_file": { + "lb203/LanguageBind-Video": "https://huggingface.co/lb203/LanguageBind-Video/resolve/main/merges.txt", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "lb203/LanguageBind-Video": 77, +} + + +PRETRAINED_INIT_CONFIGURATION = { + "lb203/LanguageBind-Video": {}, +} + +class LanguageBindVideoTokenizer(CLIPTokenizer): + """ + Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (`str`, *optional*, defaults to `<|endoftext|>`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*, defaults to `<|startoftext|>`): + The beginning of sequence token. + eos_token (`str`, *optional*, defaults to `<|endoftext|>`): + The end of sequence token. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="<|endoftext|>", + bos_token="<|startoftext|>", + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", # hack to enable padding + **kwargs, + ): + super(LanguageBindVideoTokenizer, self).__init__( + vocab_file, + merges_file, + errors, + unk_token, + bos_token, + eos_token, + pad_token, # hack to enable padding + **kwargs,) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e07b9a60bdce127bef7bb272b1d11d332591f8e0 --- /dev/null +++ b/main.py @@ -0,0 +1,659 @@ +import glob +import logging +import os +import re +import subprocess +import sys +import random +from datetime import datetime + +import numpy as np +import torch +from torch import optim +from torch.cuda.amp import GradScaler +from transformers import CLIPPreTrainedModel + +from a_cls.zeroshot_cls import evaluate_a_cls +from al_ret.retrieval import evaluate_al_ret +from i_cls.zeroshot_cls import evaluate_i_cls +from d_cls.zeroshot_cls import evaluate_d_cls +from t_cls.zeroshot_cls import evaluate_t_cls +from v_cls.zeroshot_cls import evaluate_v_cls +from vl_ret.retrieval import evaluate_vl_ret + +from model.process_clip import set_global_value, print_trainable_parameters + +try: + import wandb +except ImportError: + wandb = None + +try: + import tensorboardX as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +from data.build_datasets import get_data +from open_clip import create_model_and_transforms, create_loss +from training.distributed import is_master, init_distributed_device, broadcast_object +from training.logger import setup_logging +from training.params import parse_args +from training.scheduler import cosine_lr, const_lr, const_lr_cooldown +from training.file_utils import pt_load, start_sync_process, remote_sync +from train import train_one_epoch +from model.build_model import create_vat_model + +LATEST_CHECKPOINT_NAME = "epoch_latest.pt" +MODEL_DICT = {"ViT-L-14": "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K", + "ViT-H-14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"} +CHECKPOINT_DICT = {"ViT-L-14": "models--laion--CLIP-ViT-L-14-DataComp.XL-s13B-b90K/snapshots/84c9828e63dc9a9351d1fe637c346d4c1c4db341/pytorch_model.bin", + "ViT-H-14": "models--laion--CLIP-ViT-H-14-laion2B-s32B-b79K/snapshots/94a64189c3535c1cb44acfcccd7b0908c1c8eb23/pytorch_model.bin"} + + + + + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def natural_key(string_): + """See http://www.codinghorror.com/blog/archives/001018.html""" + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def get_latest_checkpoint(path: str, remote: bool): + # as writen, this glob recurses, so can pick up checkpoints across multiple sub-folders + if remote: + result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + print(result) + if result.returncode == 1: + return None + checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]] + else: + checkpoints = glob.glob(path + '**/*.pt', recursive=True) + if checkpoints: + checkpoints = sorted(checkpoints, key=natural_key) + return checkpoints[-1] + return None + +def SET_GLOBAL_VALUE(k, v): + set_global_value(k, v) + +def main(args): + args = parse_args(args) + + # SET_GLOBAL_VALUE('PATCH_DROPOUT', args.force_patch_dropout) + # SET_GLOBAL_VALUE('NUM_FRAMES', args.num_frames) + + if torch.cuda.is_available(): + # This enables tf32 on Ampere GPUs which is only 8% slower than + # float16 and almost as accurate as float32 + # This was a default in pytorch until 1.12 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + # fully initialize distributed device environment + device = init_distributed_device(args) + + # get the name of the experiments + if args.name is None: + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + model_name_safe = args.model.replace('/', '-') + date_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + if args.distributed: + # sync date_str from master to all ranks + date_str = broadcast_object(args, date_str) + args.name = '-'.join([ + date_str, + f"pt_{args.clip_type}", + f"text_{args.text_type}", + f"bs_{args.batch_size}", + f"ep_{args.epochs}", + f"mask_{args.force_patch_dropout}", + f"lorar_{args.lora_r}" if args.convert_to_lora else "", + f"lr_{args.lr}", + f"coeflr_{args.coef_lr}", + f"warm_{args.warmup}", + f"accum_{args.accum_freq}", + f"tattn_{args.add_time_attn}" if args.clip_type == 'vl' else "", + f"model_{model_name_safe}", + f"frm_{args.num_frames}", + f"vdb_{args.video_decode_backend}", + ]) + args.pretrained = CHECKPOINT_DICT[args.model] + args.model = MODEL_DICT[args.model] + + resume_latest = args.resume == 'latest' + log_base_path = os.path.join(args.logs, args.name) + args.log_base_path = log_base_path + args.log_path = None + if is_master(args, local=args.log_local): + os.makedirs(log_base_path, exist_ok=True) + log_filename = f'out-{args.rank}' if args.log_local else 'out.log' + args.log_path = os.path.join(log_base_path, log_filename) + if os.path.exists(args.log_path) and not resume_latest: + print( + "Error. Experiment already exists. Use --name {} to specify a new experiment." + ) + return -1 + + # Setup text logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # Setup wandb, tensorboard, checkpoint logging + args.wandb = 'wandb' in args.report_to or 'all' in args.report_to + args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to + args.checkpoint_path = os.path.join(log_base_path, "checkpoints") + if is_master(args): + args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else '' + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = '' + + if resume_latest: + resume_from = None + checkpoint_path = args.checkpoint_path + # If using remote_sync, need to check the remote instead of the local checkpoints folder. + if args.remote_sync is not None: + checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints") + if args.save_most_recent: + print('Error. Cannot use save-most-recent with remote_sync and resume latest.') + return -1 + if args.remote_sync_protocol != 's3': + print('Error. Sync protocol not supported when using resume latest.') + return -1 + if is_master(args): + # Checking for existing checkpoint via master rank only. It is possible for + # different rank processes to see different files if a shared file-system is under + # stress, however it's very difficult to fully work around such situations. + if args.save_most_recent: + # if --save-most-recent flag is set, look for latest at a fixed filename + resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME) + if not os.path.exists(resume_from): + # If no latest checkpoint has been saved yet, don't try to resume + resume_from = None + else: + # otherwise, list checkpoint dir contents and pick the newest checkpoint + resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None) + if resume_from: + logging.info(f'Found latest resume checkpoint at {resume_from}.') + else: + logging.info(f'No latest resume checkpoint found in {checkpoint_path}.') + if args.distributed: + # sync found checkpoint path to all ranks + resume_from = broadcast_object(args, resume_from) + args.resume = resume_from + + if args.copy_codebase: + copy_codebase(args) + + # start the sync proces if remote-sync is not None + remote_sync_process = None + if is_master(args) and args.remote_sync is not None: + # first make sure it works + result = remote_sync( + os.path.join(args.logs, args.name), + os.path.join(args.remote_sync, args.name), + args.remote_sync_protocol + ) + if result: + logging.info('remote sync successful.') + else: + logging.info('Error: remote sync failed. Exiting.') + return -1 + # if all looks good, start a process to do this every args.remote_sync_frequency seconds + remote_sync_process = start_sync_process( + args.remote_sync_frequency, + os.path.join(args.logs, args.name), + os.path.join(args.remote_sync, args.name), + args.remote_sync_protocol + ) + remote_sync_process.start() + + if args.precision == 'fp16': + logging.warning( + 'It is recommended to use AMP mixed-precision instead of FP16. ' + 'FP16 support needs further verification and tuning, especially for train.') + + if args.horovod: + logging.info( + f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.' + f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') + elif args.distributed: + logging.info( + f'Running in distributed mode with multiple processes. Device: {args.device}.' + f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') + else: + logging.info(f'Running with a single process. Device {args.device}.') + + dist_model = None + args.distill = args.distill_model is not None and args.distill_pretrained is not None + if args.distill: + # FIXME: support distillation with grad accum. + assert args.accum_freq == 1 + # FIXME: support distillation with coca. + assert 'coca' not in args.model.lower() + + if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1: + # arg is nargs, single (square) image size list -> int + args.force_image_size = args.force_image_size[0] + random_seed(args.seed, 0) + + ############################################################################# + # model, preprocess_train, preprocess_val = create_model_and_transforms( + # args.model, + # args.pretrained, + # precision=args.precision, + # device=device, + # jit=args.torchscript, + # force_quick_gelu=args.force_quick_gelu, + # force_custom_text=args.force_custom_text, + # force_patch_dropout=args.force_patch_dropout, + # force_image_size=args.force_image_size, + # pretrained_image=args.pretrained_image, + # image_mean=args.image_mean, + # image_std=args.image_std, + # aug_cfg=args.aug_cfg, + # output_dict=True, + # ) + + model = create_vat_model(args) + args.image_size = model.vision_model.config.image_size + ############################################################################# + + + + if args.distill: + # FIXME: currenlty assumes the model your distilling from has the same tokenizer & transforms. + dist_model, _, _ = create_model_and_transforms( + args.distill_model, + args.distill_pretrained, + device=device, + precision=args.precision, + output_dict=True, + ) + if args.use_bnb_linear is not None: + print('=> using a layer from bitsandbytes.\n' + ' this is an experimental feature which requires two extra pip installs\n' + ' pip install bitsandbytes triton' + ' please make sure to use triton 2.0.0') + import bitsandbytes as bnb + from open_clip.utils import replace_linear + print(f'=> replacing linear layers with {args.use_bnb_linear}') + linear_replacement_cls = getattr(bnb.nn.triton_based_modules, args.use_bnb_linear) + replace_linear(model, linear_replacement_cls) + model = model.to(device) + + random_seed(args.seed, args.rank) + + # if args.trace: + # model = trace_model(model, batch_size=args.batch_size, device=device) + if args.lock_image: + for param in model.vision_model.embeddings.parameters(): + param.requires_grad = False + for param in model.vision_model.pre_layrnorm.parameters(): + param.requires_grad = False + + if not args.convert_to_lora: + for param in model.vision_model.embeddings.parameters(): + param.requires_grad = False + for param in model.vision_model.pre_layrnorm.parameters(): + param.requires_grad = False + if args.add_time_attn: + for name, param in model.vision_model.encoder.layers.named_parameters(): + if 'temporal' in name: + param.requires_grad = True + else: + param.requires_grad = False + else: + for name, param in model.vision_model.encoder.layers.named_parameters(): + if 'self_attn' in name: + param.requires_grad = True + else: + param.requires_grad = False + else: + if args.add_time_attn: + for name, param in model.vision_model.encoder.layers.named_parameters(): + if 'temporal_embedding' in name or 'temporal_layer_norm1' in name: + param.requires_grad = True + + for param in model.vision_model.embeddings.position_embedding.parameters(): + param.requires_grad = False + model.vision_model.embeddings.class_embedding.requires_grad = True + + + if args.lock_text: + for param in model.text_model.parameters(): + param.requires_grad = False + for param in model.text_projection.parameters(): + param.requires_grad = False + + model.logit_scale.requires_grad = args.learn_temp + + if is_master(args): + print_trainable_parameters(model, msg='The model: ') + + if args.grad_checkpointing: + # model.text_model.encoder.gradient_checkpointing = args.grad_checkpointing + model.vision_model.encoder.gradient_checkpointing = args.grad_checkpointing + # if args.clip_type == 'vl_new': + # for m in model.vision_model.encoder.layers: + # m.gradient_checkpointing = args.grad_checkpointing + # elif args.clip_type == 'al': + # model.vision_model.encoder.gradient_checkpointing = args.grad_checkpointing + # for m in model.vision_model.encoder.layers: + # m.gradient_checkpointing = False + + + if is_master(args): + logging.info("Model:") + # logging.info(f"{str(model)}") + logging.info("Args:") + args_file = os.path.join(args.logs, args.name, "args.txt") + with open(args_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args['static_graph'] = True + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) + + if args.distill: + dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args) + + # create optimizer and scaler + ############################################################ + # if args.train_data or args.dataset_type == "synthetic": + assert not args.trace, 'Cannot train with traced model' + + no_decay = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n or 'class_embedding' in n or 'patch_embedding' in n + decay = lambda n, p: not no_decay(n, p) + + lora = lambda n, p: "lora" in n + non_lora = lambda n, p: not lora(n, p) + + named_parameters = list(model.named_parameters()) + no_decay_non_lora_params = [[n, p] for n, p in named_parameters if no_decay(n, p) and non_lora(n, p) and p.requires_grad] + decay_non_lora_params = [[n, p] for n, p in named_parameters if decay(n, p) and non_lora(n, p) and p.requires_grad] + + no_decay_lora_params = [[n, p] for n, p in named_parameters if no_decay(n, p) and lora(n, p) and p.requires_grad] + decay_lora_params = [[n, p] for n, p in named_parameters if decay(n, p) and lora(n, p) and p.requires_grad] + + + param_groups = [] + if no_decay_non_lora_params: param_groups.append({"params": [p for n, p in no_decay_non_lora_params], "weight_decay": 0., 'lr': args.lr * args.coef_lr}) + if decay_non_lora_params: param_groups.append({"params": [p for n, p in decay_non_lora_params], "weight_decay": args.wd, 'lr': args.lr * args.coef_lr}) + if no_decay_lora_params: param_groups.append({"params": [p for n, p in no_decay_lora_params], "weight_decay": 0.}) + if decay_lora_params: param_groups.append({"params": [p for n, p in decay_lora_params], "weight_decay": args.wd}) + + optimizer = optim.AdamW( + # [ + # {"params": no_decay_non_visual_params, "weight_decay": 0.}, + # {"params": decay_non_visual_params, "weight_decay": args.wd}, + # {"params": no_decay_visual_params, "weight_decay": 0., 'lr': args.lr * args.coef_lr}, + # {"params": decay_visual_params, "weight_decay": args.wd, 'lr': args.lr * args.coef_lr}, + # ], + param_groups, + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + ) + + name_groups = {} + if no_decay_non_lora_params: + name_groups['no_decay_non_lora_params'] = [{"name": n, "weight_decay": 0., 'lr': args.lr * args.coef_lr} for n, p in no_decay_non_lora_params] + if decay_non_lora_params: + name_groups['decay_non_lora_params'] = [{"name": n, "weight_decay": args.wd, 'lr': args.lr * args.coef_lr} for n, p in decay_non_lora_params] + if no_decay_lora_params: + name_groups['no_decay_lora_params'] = [{"name": n, "weight_decay": 0., 'lr': args.lr} for n, p in no_decay_lora_params] + if decay_lora_params: + name_groups['decay_lora_params'] = [{"name": n, "weight_decay": args.wd, 'lr': args.lr} for n, p in decay_lora_params] + if is_master(args): + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for group_name, group in name_groups.items(): + logging.info(f"Group name: {group_name}:") + f.write(f"Group name: {group_name}:\n") + for i in group: + logging.info(f"Parameter name: {i['name']}. Learning rate: {i['lr']}. Weight decay: {i['weight_decay']}") + f.write(f"Parameter name: {i['name']}. Learning rate: {i['lr']}. Weight decay: {i['weight_decay']}\n") + + + if args.horovod: + optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters()) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + scaler = GradScaler() if args.precision == "amp" else None + ############################################################ + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + checkpoint = pt_load(args.resume, map_location='cpu') + if 'epoch' in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith('module'): + sd = {k[len('module.'):]: v for k, v in sd.items()} + miss, unexpect = model.load_state_dict(sd, strict=False) + # print(miss, unexpect) + assert unexpect == [] or unexpect == ['text_model.embeddings.position_ids'] or unexpect == ['module.text_model.embeddings.position_ids'] + if unexpect == ['text_model.embeddings.position_ids'] or unexpect == ['module.text_model.embeddings.position_ids']: + logging.warning(f"Unexpected key: {unexpect}") + if optimizer is not None: + if args.do_train: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and 'scaler' in checkpoint: + scaler.load_state_dict(checkpoint['scaler']) + logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})") + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") + + # initialize datasets + data = get_data(args, epoch=start_epoch) + if is_master(args): + logging.info(f"{data})") + assert len(data), 'At least one train or eval dataset must be specified.' + + # create scheduler if train + scheduler = None + if f'{args.clip_type}_pt' in data and optimizer is not None: + total_steps = (data[f'{args.clip_type}_pt'].dataloader.num_batches // args.accum_freq) * args.epochs + if args.lr_scheduler == "cosine": + scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) + elif args.lr_scheduler == "const": + scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps) + elif args.lr_scheduler == "const-cooldown": + assert args.epochs_cooldown is not None, \ + "Please specify the number of cooldown epochs for this lr schedule." + cooldown_steps = (data[f'{args.clip_type}_pt'].dataloader.num_batches // args.accum_freq) * args.epochs_cooldown + scheduler = const_lr_cooldown( + optimizer, args.lr, args.warmup, total_steps, + cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end) + else: + logging.error( + f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.') + exit(1) + + # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + # if args.wandb and is_master(args): + # assert wandb is not None, 'Please install wandb.' + # logging.debug('Starting wandb.') + # args.train_sz = data["train"].dataloader.num_samples + # if args.val_data is not None: + # args.val_sz = data["val"].dataloader.num_samples + # # you will have to configure this for your project! + # wandb.init( + # project=args.wandb_project_name, + # name=args.name, + # id=args.name, + # notes=args.wandb_notes, + # tags=[], + # resume='auto' if args.resume == "latest" else None, + # config=vars(args), + # ) + # if args.debug: + # wandb.watch(model, log='all') + # wandb.save(params_file) + # logging.debug('Finished loading wandb.') + + if args.torchcompile: + logging.info('Compiling model...') + model = torch.compile(model) + + if f'{args.clip_type}_pt' not in data: + # If using int8, convert to inference mode. + if args.use_bnb_linear is not None: + from open_clip.utils import convert_int8_model_to_inference_mode + convert_int8_model_to_inference_mode(model) + # Evaluate. + if "i_cls" in data: + evaluate_i_cls(model, data, start_epoch, args, writer) + if "vl_ret" in data: + for sub_data in data['vl_ret']: + evaluate_vl_ret(model, sub_data, start_epoch, args, writer) + if "a_cls" in data: + for sub_data in data['a_cls']: + evaluate_a_cls(model, sub_data, start_epoch, args, writer) + if "al_ret" in data: + for sub_data in data['al_ret']: + evaluate_al_ret(model, sub_data, start_epoch, args, writer) + if "v_cls" in data: + for sub_data in data['v_cls']: + evaluate_v_cls(model, sub_data, start_epoch, args, writer) + if "d_cls" in data: + for sub_data in data['d_cls']: + evaluate_d_cls(model, sub_data, start_epoch, args, writer) + if "t_cls" in data: + for sub_data in data['t_cls']: + evaluate_t_cls(model, sub_data, start_epoch, args, writer) + return + + loss = create_loss(args) + + for epoch in range(start_epoch, args.epochs): + if is_master(args): + logging.info(f'Start epoch {epoch}') + + train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer) + completed_epoch = epoch + 1 + + if "i_cls" in data: + evaluate_i_cls(model, data, completed_epoch, args, writer) + if "vl_ret" in data: + for sub_data in data['vl_ret']: + evaluate_vl_ret(model, sub_data, completed_epoch, args, writer) + if "a_cls" in data: + for sub_data in data['a_cls']: + evaluate_a_cls(model, sub_data, completed_epoch, args, writer) + if "al_ret" in data: + for sub_data in data['al_ret']: + evaluate_al_ret(model, sub_data, completed_epoch, args, writer) + if "v_cls" in data: + for sub_data in data['v_cls']: + evaluate_v_cls(model, sub_data, completed_epoch, args, writer) + if "d_cls" in data: + for sub_data in data['d_cls']: + evaluate_d_cls(model, sub_data, completed_epoch, args, writer) + if "t_cls" in data: + for sub_data in data['t_cls']: + evaluate_t_cls(model, sub_data, completed_epoch, args, writer) + + # Saving checkpoints. + if args.save_logs: + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + } + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.delete_previous_checkpoint: + previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt") + if os.path.exists(previous_checkpoint): + os.remove(previous_checkpoint) + + if args.save_most_recent: + # try not to corrupt the latest checkpoint if save fails + tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt") + latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME) + torch.save(checkpoint_dict, tmp_save_path) + os.replace(tmp_save_path, latest_save_path) + + if args.wandb and is_master(args): + wandb.finish() + + # run a final sync. + if remote_sync_process is not None: + logging.info('Final remote sync.') + remote_sync_process.terminate() + result = remote_sync( + os.path.join(args.logs, args.name), + os.path.join(args.remote_sync, args.name), + args.remote_sync_protocol + ) + if result: + logging.info('Final remote sync successful.') + else: + logging.info('Final remote sync failed.') + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." + ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb')) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/modality_generation_codes/depth_ddp_glpn.py b/modality_generation_codes/depth_ddp_glpn.py new file mode 100644 index 0000000000000000000000000000000000000000..136027c7912f3ecb79cf268747da2433edcdf7c7 --- /dev/null +++ b/modality_generation_codes/depth_ddp_glpn.py @@ -0,0 +1,547 @@ +import sys + +from PIL import Image +from torchvision import transforms +# from transformers import OFATokenizer, OFAModel +# from transformers.models.ofa.generate import sequence_generator # from generate import sequence_generator +from transformers import DPTImageProcessor, DPTForDepthEstimation +import os.path +from argparse import ArgumentParser +from torch.utils import data +import json +import torch +import torch.distributed as dist +import os +import os.path as osp +from os.path import join as opj +import pandas as pd +from random import randint +import cv2 +import torch +from torch.utils.data import Dataset, DataLoader +import decord +import glob +import subprocess +import time +import numpy as np + +import os +os.environ["HF_DATASETS_OFFLINE"] = "1" + +import decord +from decord import cpu +#glpn +from transformers import GLPNFeatureExtractor, GLPNForDepthEstimation +import torch +import numpy as np +from PIL import Image +import requests +import io + +import cv2 +import numpy as np +from decord import VideoReader, cpu + +try: + from petrel_client.client import Client + petrel_backend_imported = True +except (ImportError, ModuleNotFoundError): + petrel_backend_imported = False + + +def get_video_loader(use_petrel_backend: bool = True, + enable_mc: bool = True, + conf_path: str = None): + if petrel_backend_imported and use_petrel_backend: + _client = Client(conf_path=conf_path, enable_mc=enable_mc) + else: + _client = None + + def _loader(video_path): + if _client is not None and 's3:' in video_path: + video_path = io.BytesIO(_client.get(video_path)) + + vr = VideoReader(video_path, num_threads=1, ctx=cpu(0)) + return vr + + return _loader + +class my_dataset(Dataset): + def __init__(self, args): + super().__init__() + self.args = args + self.shuffle = True + self.resolution = args.resolution # 对于动态大小视频无用 + self.loader = get_video_loader() + + if args.train_file.endswith('.csv'): + self.train_file = pd.read_csv(args.train_file) + elif args.train_file.endswith('.json'): + # coco_vat_vat0_11_all_id_rootfolder_clsidx_spacy.json + # 格式: id : { 'idx_list' : [0], 'root_folder' : 'coco_vat_9' } + + if hasattr(args, 'part_nums') and args.part_nums >1: + self.part_nums = args.part_nums + else: + self.part_nums = 100000 + self.part_index = args.part_index + t1 = time.time() + with open(args.train_file, 'r', encoding='utf-8') as f: + self.train_file = json.load(f) + if type(self.train_file) is str: + self.train_file = json.loads(self.train_file) + + self.id_list = list(self.train_file.keys()) + #============================= + # obtain subset of self.id_list + self.id_list = self.id_list[self.part_nums*(self.part_index-1):self.part_nums*self.part_index] + + + print(f'Nums of train_file is {len(self.id_list)},part_index:{self.part_index}, first:{self.id_list[0]}') + self.no_caption_id_list = [] + for idx,id in enumerate(self.id_list): + caption_json = osp.join('/apdcephfs_cq3/share_1311970/A_Youtube',self.train_file[id]['root_folder'],f'{id}_depth_f8glpn_folder') + mp4_path = osp.join('/apdcephfs_cq3/share_1311970/A_Youtube',self.train_file[id]['root_folder'],f'{id}.mp4') + if not os.path.exists(caption_json) and os.path.exists(mp4_path): + self.no_caption_id_list.append(mp4_path) + # else: + # print(f'{caption_json} is exist!') + if idx%10000==0: + print(f'Time_cost:{time.time()-t1}s, idx:{idx}, caption_json:{caption_json}') + try: + print(f'Nums of no_depth_folder_id_list is {len(self.no_caption_id_list)}, first:{self.no_caption_id_list[0]}') + except: + print(f'Nums of no_depth_folder_id_list is {len(self.no_caption_id_list)}') + t2 = time.time() + print(f'Time cost:{t2-t1}s') + # DPT + # self.patch_resize_transform = DPTImageProcessor.from_pretrained("Intel/dpt-large", cache_dir= args.weights_folder) + # glpn + self.patch_resize_transform = GLPNFeatureExtractor.from_pretrained("vinvino02/glpn-nyu", cache_dir= args.weights_folder) + print('Dataset nums is {}'.format(self.__len__())) + time.sleep(10) + + + def __len__(self): + return len(self.no_caption_id_list) + + def random_sample(self): + return self.__getitem__(randint(0, self.__len__() - 1)) + + def sequential_sample(self, idx): + if idx >= self.__len__() - 1: + return self.__getitem__(0) + return self.__getitem__(idx + 1) + + def skip_sample(self, idx): + if self.shuffle: + return self.random_sample() + return self.sequential_sample(idx=idx) + + def resize_frame(self, frame): + height, width = frame.shape[:2] + if height < width: + new_height = 256 + new_width = int(width * (new_height / height)) + else: + new_width = 256 + new_height = int(height * (new_width / width)) + # resized_frame = cv2.resize(frame, (new_width, new_height)) + resized_frame = cv2.resize(frame, (448, 796)) # 576*448 796,448 + return resized_frame, new_width, new_height # frame的形状是和new_w, new_w不一样的!! + + + def get_frames_from_video_decord(self, batchsize=1, video_path=None, caption_nums_per_video=8): + # 加载视频 + video_path = video_path + vr = self.loader(video_path) + frame_width, frame_height = vr[0].shape[1], vr[0].shape[0] + + # 确定要提取的帧数 + num_frames = caption_nums_per_video + # 计算每隔多少帧提取一次 + total_frames = len(vr) + step = total_frames // num_frames + + # 用于存储提取的图像的tensor + frames = [] + height = width = 0 + + # 直接读取指定帧 + for i in range(num_frames): + # 计算要提取的帧的索引 + idx = i * step + + # 读取该帧 + frame = vr[idx].asnumpy() + frame, new_width, new_height = self.resize_frame(frame) + + if i == 0: + height, width = new_height, new_width + + # 转换为PIL Image并进行缩放 + # Image_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + Image_frame = Image.fromarray(frame) + frame = self.patch_resize_transform(images=Image_frame, return_tensors="pt").pixel_values.unsqueeze(0) + + # 将numpy数组转换为tensor并存储在frames中 + frames.append(frame) + + # vr.close() + frames = torch.cat(frames, 1).squeeze(0) + return frames, height, width + + + def get_frames_from_video_opencv(self, batchsize=1, video_path=None, caption_nums_per_video=8): + # 加载视频 + video_path = video_path + cap = cv2.VideoCapture(video_path) + frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + + # 确定要提取的帧数 + num_frames = caption_nums_per_video + # 计算每隔多少帧提取一次 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + step = total_frames // num_frames + + # 用于存储提取的图像的tensor + # frames = torch.empty(num_frames, 3, frame_height, frame_width) + frames = [] + height = width = 0 + # frame_idx_list=[] 把视频的帧序号存储下来 + # 直接读取指定帧 + for i in range(num_frames): + # 计算要提取的帧的索引 + idx = i * step + # frame_idx_list.append(idx) + # 设置当前帧为所需的帧 + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + # 读取该帧 + ret, frame = cap.read() + frame, new_width, new_height = self.resize_frame(frame) # frame size是按照(w=448, h=576)resize的,但是new_h, new_w是按照原视频宽高比例缩放到短边为256,这样可以保持视频物体比例并且减少内存占用 + # print(f'{frame_width},{frame_height }===== {frame.shape} ') + if i==0: + # height, width, _ = frame.shape + height, width = new_height, new_width + # print(f'{frame.shape}, {new_width}, {new_height}!!!') #(576, 448, 3), 256, 455!!! + if not ret: + break + # ret, frame = cap.read() # frame.shape (h,w,3) + + # frame = self.resize_frame(frame) + + # # print(f'{frame_width},{frame_height }===== {frame.shape} ') + # if i==0: + # height, width, _ = frame.shape + # if not ret: + # break + + # 转换为PIL Image并进行缩放 + Image_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + # frame = self.patch_resize_transform(Image_frame) + frame = self.patch_resize_transform(images=Image_frame, return_tensors="pt").pixel_values.unsqueeze(0) + # print(frame.shape) + + # 将numpy数组转换为tensor并存储在frames中 + # frames[i] = frame + frames.append(frame) + + # 打印输出frames的形状 + # print(frames.shape) + cap.release() + frames = torch.cat(frames, 1).squeeze(0) + return frames, height, width + + + + + def __getitem__(self, idx): + + try: + + # video_id = self.filter_train_file[idx] + # video_path = opj(self.vat_root, video_id) + video_path = self.no_caption_id_list[idx] + video_id = video_path.split('/')[-1].split('.')[0] + # 假如多个程序一起跑,其他已经生成了,就跳过 + caption_video_json = video_path.replace('.mp4', '_depth_f8glpn_folder') + if os.path.exists(caption_video_json): + print('parallel task has process it :{}'.format(caption_video_json)) + # return '===========', None, torch.random(8,3,self.resolution, self.resolution) + return self.skip_sample(idx) + if not osp.exists(video_path): + print('video {} is not exists and skip this idx! '.format(video_path)) + return self.skip_sample(idx) + video_frames, height, width = self.get_frames_from_video_opencv( video_path = video_path, caption_nums_per_video = args.caption_nums_per_video) + # video_frames, height, width = self.get_frames_from_video_decord( video_path = video_path, caption_nums_per_video = args.caption_nums_per_video) + return video_id, video_path, video_frames, height, width + + except Exception as e: + print('Read video error in {},{} and we have skip this !, this will not cause error!'.format(idx,e)) + return self.skip_sample(idx) + + +def synchronize(): + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +def depth_estimation(args): + + ######################################## model start ############################# + """https://huggingface.co/docs/transformers/main/en/model_doc/dpt""" + + weights_folder = args.weights_folder + print(f'args.weights_folder is {args.weights_folder}') + + model_name = 'glpn' + if model_name == 'glpn': + # glpn + feature_extractor = GLPNFeatureExtractor.from_pretrained("vinvino02/glpn-nyu", cache_dir= weights_folder) + model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-nyu", cache_dir= weights_folder).cuda(args.local_rank) + else: + # DPT + processor = DPTImageProcessor.from_pretrained("Intel/dpt-large", cache_dir= weights_folder) + model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large", cache_dir= weights_folder).cuda(args.local_rank) + + if args.rank==0: + print('模型初始化完成') + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank) + model.eval() + if args.rank==0: + print('DDP model') + ######################################## model over ############################# + + ######################################## dataset start ############################# + if args.rank == 0: + print('dataset 初始化') + + train_dataset = my_dataset(args) + if args.rank == 0: + print('dataset_len: ',train_dataset.__len__()) + print('loading dataset is complete!') + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + num_replicas=args.world_size, + rank=args.rank + ) + if args.rank == 0: + print('正在同步') + synchronize() + if args.rank == 0: + print('dataloader 初始化') + dataloader = torch.utils.data.DataLoader(dataset=train_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + sampler=train_sampler, + drop_last=True + ) + ######################################## dataset over ############################# + ######################################## depth_estimation start ############################# + + for index, (video_ids, video_paths, videos_frames_, h_list, w_list) in enumerate(dataloader): + + bs, cap_nums, c, h, w = videos_frames_.shape + videos_frames = videos_frames_.view(-1, c, h, w ).cuda(args.local_rank) + torch.cuda.empty_cache() + try: + with torch.no_grad(): + outputs = model(videos_frames) + predicted_depth = outputs.predicted_depth # (bs*cap_nums, h, w) + + + predicted_depth = predicted_depth.view(bs, cap_nums, h, w) + # print(f'predicted_depth.shape:{predicted_depth.shape}') + # interpolate to original size + for bs_idx, sample in enumerate(predicted_depth): + # import ipdb + # ipdb.set_trace() + pic_folder = video_paths[bs_idx].replace('.mp4','_depth_f8glpn_folder') + os.makedirs(pic_folder, exist_ok=True) + for frame_idx, frame in enumerate(predicted_depth[bs_idx]): + prediction = torch.nn.functional.interpolate( + frame.unsqueeze(0).unsqueeze(0), # torch.Size([1, 1, 384, 384]) + size=(h_list[bs_idx],w_list[bs_idx]), + mode="bicubic", + align_corners=False, + ) # torch.Size([1, 1, h=480, w=640]) + # print('prediction.shape:{prediction.shape}') + # visualize the prediction + output = prediction.squeeze().cpu().numpy() + # formatted = (output * 255 / np.max(output)).astype("uint8") + # depth = Image.fromarray(formatted) # size (576, 1024) + # depth.save(f"{pic_folder}/{frame_idx}.png") + max_depth = 10 + if np.any(output>10): + print(f"{pic_folder} > 10") + output_1k = np.clip(output, 0, max_depth)*1000 + cv2.imwrite(f"{pic_folder}/{frame_idx}.png", output_1k.astype("uint16"), [cv2.IMWRITE_PNG_COMPRESSION, 0]) + print(f'{pic_folder} is succeed!') + # sys.exit(0) + del videos_frames, outputs, predicted_depth + except Exception as e: + print(f'Error:{e}!') + del videos_frames + ######################################## depth_estimation over ############################# +def init_distributed_mode(args): + + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.local_rank = int(os.environ['LOCAL_RANK']) + torch.cuda.set_device(args.local_rank) + + + args.dist_backend = 'nccl' + args.dist_url = 'env://' + + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank, timeout=datetime.timedelta(seconds=5400)) + torch.distributed.barrier() + + +import utils.misc as misc + +def main(args): + misc.init_distributed_mode(args) + if args.rank == 0: + print('进程组初始化完成') + print("started") + print("started caption_json count!") + # glob1(json_path=args.exist_caption_id_list_json) # 'coco_vat_exist_caption_id_list_03141026.json' + ###########################################################3 + import time + t1=time.time() + depth_estimation(args) + t2 = time.time() + if args.rank == 0: + print('Time : ',t2-t1,' s') + dist.destroy_process_group() # 销毁进程组 + + +def test_dataset(args): + + train_dataset = my_dataset(args) + loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers) + from time import time + for i, sample in enumerate(loader): + video_ids, video_paths, videos_frames, h, w = sample + print(i, video_ids, video_paths, videos_frames.shape, h, w) + +def glpn(): + + from transformers import GLPNFeatureExtractor, GLPNForDepthEstimation + import torch + import numpy as np + from PIL import Image + import requests + + max_depth = 10 + + # url = "http://images.cocodataset.org/val2017/000000039769.jpg" + # image = Image.open(requests.get(url, stream=True).raw) + image = Image.open('hallo.png') + weights_folder = './glpn' + feature_extractor = GLPNFeatureExtractor.from_pretrained("vinvino02/glpn-nyu", cache_dir= weights_folder) + model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-nyu", cache_dir= weights_folder) + + # prepare image for the model + inputs = feature_extractor(images=image, return_tensors="pt") + # import ipdb + # ipdb.set_trace() + # print(inputs.shape) + + with torch.no_grad(): + outputs = model(**inputs) + predicted_depth = outputs.predicted_depth + + # interpolate to original size + prediction = torch.nn.functional.interpolate( + predicted_depth.unsqueeze(1), + size=image.size[::-1], + mode="bicubic", + align_corners=False, + ) + + # visualize the prediction + output = prediction.squeeze().cpu().numpy() + # formatted = (output * 255 / np.max(output)).astype("uint8") + # depth = Image.fromarray(formatted) + # print(f'min_output:{torch.min(output)}, max_output:{torch.max(output)}!!') + output_1k = np.clip(output, 0, max_depth)*1000 + import ipdb + ipdb.set_trace() + cv2.imwrite('./saved_10000jpg.jpg', output_1k.astype("uint16"), [cv2.IMWRITE_PNG_COMPRESSION, 0]) + # cv2.imwrite('./image1.jpg', output_1k.astype("uint16"), [cv2.IMWRITE_PNG_COMPRESSION, 0]) + cv2.imwrite('./saved_10000png.png', output_1k.astype("uint16"), [cv2.IMWRITE_PNG_COMPRESSION, 0]) + # cv2.imwrite('./image2.png', output_1k.astype("uint16"), [cv2.IMWRITE_PNG_COMPRESSION, 0]) + + +# glpn() +if __name__ == "__main__": + + + import time + # time.sleep(10000) + import datetime + # 获取当前时间 + now = datetime.datetime.now() + # 获取当前月份 + month = now.month + # 获取当前日期 + day = now.day + # 获取当前小时 + hour = now.hour + + parser = ArgumentParser() + parser.add_argument('--caption_nums_per_video', type=int, default=8, help='process rank') + parser.add_argument('--batch_size', type=int, default=2) + parser.add_argument('--train_file', type=str, required=True) + parser.add_argument('--vat_root', type=str,default=None) + parser.add_argument('--num_workers', type=int, default=1) + parser.add_argument('--resolution', type=int, default=480) + # parser.add_argument('--exist_caption_id_list_json', type=str, default=f'coco_vat_exist_caption_id_list_{month}{day}{hour}.json',help='') + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_on_itp', action='store_true') # --dist_on_itp ddp + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--distributed', default=True, help='url used to set up distributed training') + parser.add_argument('--gpus', default=[0, 1, 2, 3], help='DP CUDA devices') + parser.add_argument('--part_index', default=1, type=int, help='used to split train_file_id into different parts, and generate caption from part_index 1 to ....') + parser.add_argument('--part_nums', default=1000, type=int, help='used to split train_file_id into different parts, and generate caption from part_index 1 to ....') + parser.add_argument('--weights_folder', type=str,default='/apdcephfs_cq3/share_1311970/A_ofa/glpn') + args = parser.parse_args() + # test_dataset(args) + # import ipdb + # ipdb.set_trace() + main(args) + synchronize() + # success_file=f"part_{args.part_index}_success" + success_file=f"/apdcephfs_cq3/share_1311970/A_depth_glpn/part_{args.part_index}_success" + os.system(f"touch {success_file}") + + """ + + HF_DATASETS_OFFLINE=1 python3 -m torch.distributed.launch --nproc_per_node 1 --master_port 29504 depth_ddp_glpn.py \ + --train_file /apdcephfs_cq3/share_1311970/A_Youtube/coco_vat_vat0_11_all_id_rootfolder_clsidx_spacy.json \ + --num_workers 1 --batch_size 2 \ + --part_index 92 \ + --part_nums 10000 \ + --weights_folder /apdcephfs_cq3/share_1311970/A_ofa/glpn \ + --resolution 0 \ + --caption_nums_per_video 8 + + """ + + + diff --git a/modality_generation_codes/ofa_ddp.py b/modality_generation_codes/ofa_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..c39f10b34b7a92b02ed830c3fd2b4742c5b9cff4 --- /dev/null +++ b/modality_generation_codes/ofa_ddp.py @@ -0,0 +1,458 @@ + +from PIL import Image +from torchvision import transforms +from transformers import OFATokenizer, OFAModel +from transformers.models.ofa.generate import sequence_generator # from generate import sequence_generator +import os.path +from argparse import ArgumentParser +from torch.utils import data +import json +import torch +import torch.distributed as dist +import os +import os.path as osp +from os.path import join as opj +import pandas as pd +from random import randint +import cv2 +import torch +from torch.utils.data import Dataset, DataLoader +import decord +import glob +import subprocess +import time + + + + + +class my_dataset(Dataset): + def __init__(self, args): + super().__init__() + self.args = args + self.shuffle = True + self.resolution = args.resolution + + if args.train_file.endswith('.csv'): + self.train_file = pd.read_csv(args.train_file) + elif args.train_file.endswith('.json'): + # coco_vat_vat0_11_all_id_rootfolder_clsidx_spacy.json + # 格式: id : { 'idx_list' : [0], 'root_folder' : 'coco_vat_9' } + + if hasattr(args, 'part_nums') and args.part_nums >1: + self.part_nums = args.part_nums + else: + self.part_nums = 100000 + self.part_index = args.part_index + t1 = time.time() + with open(args.train_file, 'r', encoding='utf-8') as f: + self.train_file = json.load(f) + if type(self.train_file) is str: + self.train_file = json.loads(self.train_file) + + self.id_list = list(self.train_file.keys()) + #============================= + # obtain subset of self.id_list, so that deduplication time is less than 30min + self.id_list = self.id_list[self.part_nums*(self.part_index-1):self.part_nums*self.part_index] + + + print(f'Nums of train_file is {len(self.id_list)},part_index:{self.part_index}, first:{self.id_list[0]}') + self.no_caption_id_list = [] + self.exist_caption_path_list = {} + for idx, id in enumerate(self.id_list): + caption_json = osp.join('/apdcephfs_cq3/share_1311970/A_Youtube',self.train_file[id]['root_folder'],f'{id}_caption.json') + mp4_path = osp.join('/apdcephfs_cq3/share_1311970/A_Youtube',self.train_file[id]['root_folder'],f'{id}.mp4') + "existcap的数目包括video不存在的,所以有点虚大" + if not os.path.exists(caption_json) and os.path.exists(mp4_path): + self.no_caption_id_list.append(mp4_path) + else: + self.exist_caption_path_list[caption_json]=True + + if idx%10000==0: + print(f'Time_cost:{time.time()-t1}s, idx:{idx}, caption_json:{caption_json}') + print(f'Nums of no_caption_id_list is {len(self.no_caption_id_list)}, first:{self.no_caption_id_list[0]}') + print(f'Nums of exist_caption_path_list is {len(self.exist_caption_path_list)}') + if args.rank==0: + success_file=f"part_{args.part_index}_success_nocap_{len(self.no_caption_id_list)}_existcap{len(self.exist_caption_path_list)}" + os.system(f"touch {success_file}") + t2 = time.time() + print(f'Time cost:{t2-t1}s') + # print('======',self.exist_file_list,'====',self.no_caption_id_list) + + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + self.patch_resize_transform = transforms.Compose([ + lambda image: image.convert("RGB"), + transforms.Resize((self.resolution, self.resolution), interpolation=Image.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std) + ]) + print('Dataset nums is {}'.format(self.__len__())) + time.sleep(10) + + + def __len__(self): + return len(self.no_caption_id_list) + + def random_sample(self): + return self.__getitem__(randint(0, self.__len__() - 1)) + + def sequential_sample(self, idx): + if idx >= self.__len__() - 1: + return self.__getitem__(0) + return self.__getitem__(idx + 1) + + def skip_sample(self, idx): + if self.shuffle: + return self.random_sample() + return self.sequential_sample(idx=idx) + + + def get_frames_from_video_opencv(self, batchsize=1, video_path=None, caption_nums_per_video=8): + # 加载视频 + video_path = video_path + cap = cv2.VideoCapture(video_path) + + # 确定要提取的帧数 + num_frames = caption_nums_per_video + # 计算每隔多少帧提取一次 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + step = total_frames // num_frames + + # 用于存储提取的图像的tensor + frames = torch.empty(num_frames, 3, self.resolution, self.resolution) + + # 直接读取指定帧 + for i in range(num_frames): + # 计算要提取的帧的索引 + idx = i * step + # 设置当前帧为所需的帧 + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + # 读取该帧 + ret, frame = cap.read() + if not ret: + break + + # 转换为PIL Image并进行缩放 + Image_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + frame = self.patch_resize_transform(Image_frame) + + # 将numpy数组转换为tensor并存储在frames中 + frames[i] = frame + + # 打印输出frames的形状 + # print(frames.shape) + cap.release() + return frames + + + def get_frames_from_video(self, batchsize=1, video_path=None, caption_nums_per_video = 8, ): + + # 加载视频 + video_path = video_path + vr = decord.VideoReader(video_path) + # 确定要提取的帧数 + num_frames = caption_nums_per_video + # 计算每隔多少帧提取一次 + step = len(vr) // num_frames + # 用于存储提取的图像的tensor + frames = torch.empty(num_frames, 3, self.resolution, self.resolution) + + # 从视频中提取图像 + for i in range(num_frames): + # 计算要提取的帧的索引 + idx = i * step + # 从视频中读取帧 + decord_frame = vr[idx].asnumpy() + Image_frame = Image.fromarray(decord_frame) + frame = self.patch_resize_transform(Image_frame)#.unsqueeze(0) + + # 将numpy数组转换为tensor并存储在frames中 + frames[i] = frame + + # 打印输出frames的形状 + # print(frames.shape) + vr.close() + return frames + + + def __getitem__(self, idx): + + try: + + # video_id = self.filter_train_file[idx] + # video_path = opj(self.vat_root, video_id) + video_path = self.no_caption_id_list[idx] + video_id = video_path.split('/')[-1].split('.')[0] + # 假如多个程序一起跑,其他已经生成了,就跳过 + caption_video_json = video_path.replace('.mp4', '_caption.json') + if caption_video_json in self.exist_caption_path_list: + print('parallel task has process it :{}, this is duplication!!!!!!!!!!!!!!!!!!!!!!!!!!'.format(caption_video_json)) + # return '===========', None, torch.random(8,3,self.resolution, self.resolution) + # return self.skip_sample(idx) + # if os.path.exists(caption_video_json): + # print('parallel task has process it :{}'.format(caption_video_json)) + # # return '===========', None, torch.random(8,3,self.resolution, self.resolution) + # return self.skip_sample(idx) + if not osp.exists(video_path): + print('video {} is not exists and skip this idx! '.format(video_path)) + return self.skip_sample(idx) + video_frames = self.get_frames_from_video_opencv( video_path = video_path, caption_nums_per_video = args.caption_nums_per_video) + + return video_id, video_path, video_frames + + except Exception as e: + print('Read video error in {},{} and we have skip this !, this will not cause error!'.format(idx,e)) + return self.skip_sample(idx) + + + +def synchronize(): + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +def ids_captions_save(args, video_ids, video_paths, caption_list): + for i, video_path in enumerate(video_paths): + caption_video_json = video_path.replace('.mp4', '_caption.json') + + video_16captions = caption_list[i * args.caption_nums_per_video : (i+1) * args.caption_nums_per_video] + video_caption_dict = { video_ids[i] : video_16captions } + + if osp.exists(caption_video_json): + print('{} is exist, please check your train file'.format(caption_video_json)) + continue + with open(caption_video_json, 'w', encoding = 'utf-8') as f: + json.dump(video_caption_dict, f) + print('Success :{}'.format(caption_video_json)) + + +def ofa(args): + + """https://huggingface.co/OFA-Sys/ofa-large""" + + ######################################## model start ############################# + + ckpt_dir = 'OFA-Sys/ofa-large-caption' + # ckpt_dir = 'ofa-large-caption' + tokenizer = OFATokenizer.from_pretrained(ckpt_dir) + # tokenizer = OFATokenizer.from_pretrained(ckpt_dir, use_fast=False) + model = OFAModel.from_pretrained(ckpt_dir, use_cache=True).cuda(args.local_rank) + if args.rank==0: + print('模型初始化完成') + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank) + model.eval() + if args.rank==0: + print('DDP model') + ######################################## model over ############################# + + + ######################################## dataset start ############################# + if args.rank == 0: + print('dataset 初始化') + + # 好像报错我记得 + # if not osp.exists(args.exist_caption_id_list_json): + + # command = "find {} -name '*_caption.json'".format(args.vat_root) + # output = subprocess.check_output(command, shell=True).decode().strip() + + # # 将输出结果按行拆分并保存到一个列表中 + # file_list = output.split('\n') + # # 将列表转换为JSON字符串 + # json_list = json.dumps(file_list) + + # # 将JSON字符串写入文件 + # json_path = args.exist_caption_id_list_json + # with open(json_path, 'w', encoding='utf-8') as f: + # f.write(json_list) + # print('{} is saved, nums of caption file is {}'.format(json_path,len(file_list))) + + train_dataset = my_dataset(args) + if args.rank == 0: + print('dataset_len: ',train_dataset.__len__()) + print('loading dataset is complete!') + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + num_replicas=args.world_size, + rank=args.rank + ) + if args.rank == 0: + print('正在同步') + synchronize() + if args.rank == 0: + print('dataloader 初始化') + dataloader = torch.utils.data.DataLoader(dataset=train_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + sampler=train_sampler, + drop_last=True, + ) + ######################################## dataset over ############################# + + ######################################## ofa caption start ############################# + txt = " what does the image describe?" + inputs_ids = tokenizer([txt for i in range(args.batch_size * args.caption_nums_per_video)], return_tensors="pt").input_ids + + for index, (video_ids, video_paths, videos_frames) in enumerate(dataloader): + bs, cap_nums, c, h, w = videos_frames.shape + videos_frames = videos_frames.view(-1, c, h, w ) + + # import ipdb + # ipdb.set_trace() + gen = model.module.generate(inputs_ids.cuda(args.local_rank), patch_images=videos_frames.cuda(args.local_rank), num_beams=5, no_repeat_ngram_size=3) + caption_list = tokenizer.batch_decode(gen, skip_special_tokens=True) + ids_captions_save(args, video_ids, video_paths, caption_list) + ######################################## ofa caption over ############################# +def init_distributed_mode(args): + + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.local_rank = int(os.environ['LOCAL_RANK']) + torch.cuda.set_device(args.local_rank) + + + args.dist_backend = 'nccl' + args.dist_url = 'env://' + + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank, timeout=datetime.timedelta(seconds=5400)) + torch.distributed.barrier() + + +import utils.misc as misc + +def glob1(path = '/apdcephfs_cq3/share_1311970/A_Youtube/coco_vat', json_path = None): + import time + t1 = time.time() + import glob + file_list = glob.glob('{}/*_caption.json'.format(path), recursive=True) + json_list = json.dumps(file_list) + + print(f'caption.json sum is {len(json_list)}') + + # 将JSON字符串写入文件 + # json_path = 'coco_vat_exist_caption_id_list_03141026.json' + with open(json_path, 'w', encoding='utf-8') as f: + f.write(json_list) + print('{} is saved, nums of caption file is {}'.format(json_path,len(file_list))) + t2 = time.time() + print('!!!!!!!!{}s'.format(t2-t1)) + return file_list + +def main(args): + # args.rank = int(os.environ['RANK']) # 获取当前进程号 + # args.world_size = int(os.environ['WORLD_SIZE']) + # args.local_rank = int(os.environ['LOCAL_RANK']) + # torch.cuda.set_device(args.local_rank) + + # dist.init_process_group( + # backend='nccl',init_method='env://',world_size=args.world_size,rank=args.rank + # ) + # assert torch.distributed.is_initialized() + # dist.barrier() + misc.init_distributed_mode(args) + if args.rank == 0: + print('进程组初始化完成') + print("started") + print("started caption_json count!") + # glob1(json_path=args.exist_caption_id_list_json) # 'coco_vat_exist_caption_id_list_03141026.json' + ###########################################################3 + import time + t1=time.time() + ofa(args) + t2 = time.time() + if args.rank == 0: + print('Time : ',t2-t1,' s') + dist.destroy_process_group() # 销毁进程组 + + +def test_dataset(args): + # command = "find {} -name '*_caption.json'".format(args.vat_root) + # output = subprocess.check_output(command, shell=True).decode().strip() + + # # 将输出结果按行拆分并保存到一个列表中 + # file_list = output.split('\n') + # # 将列表转换为JSON字符串 + # json_list = json.dumps(file_list) + + # # 将JSON字符串写入文件 + # json_path = args.exist_caption_id_list_json + # with open(json_path, 'w', encoding='utf-8') as f: + # f.write(json_list) + # print('{} is saved, nums of caption file is {}'.format(json_path,len(file_list))) + + train_dataset = my_dataset(args) + loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers) + + from time import time + for i, sample in enumerate(loader): + video_ids, video_paths, videos_frames = sample + + # import ipdb + # ipdb.set_trace() + print(i, video_ids, video_paths, videos_frames.shape) + + + + +if __name__ == "__main__": + import time + # time.sleep(10000) + import datetime + + # 获取当前时间 + now = datetime.datetime.now() + + # 获取当前月份 + month = now.month + + # 获取当前日期 + day = now.day + + # 获取当前小时 + hour = now.hour + + + parser = ArgumentParser() + parser.add_argument('--caption_nums_per_video', type=int, default=8, help='process rank') + parser.add_argument('--batch_size', type=int, default=2) + parser.add_argument('--train_file', type=str, required=True) + parser.add_argument('--vat_root', type=str,default=None) + parser.add_argument('--num_workers', type=int, default=1) + parser.add_argument('--resolution', type=int, default=480) + # parser.add_argument('--exist_caption_id_list_json', type=str, default=f'coco_vat_exist_caption_id_list_{month}{day}{hour}.json',help='') + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_on_itp', action='store_true') # --dist_on_itp ddp + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--distributed', default=True, help='url used to set up distributed training') + parser.add_argument('--gpus', default=[0, 1, 2, 3], help='DP CUDA devices') + parser.add_argument('--part_index', default=1, type=int, help='used to split train_file_id into different parts, and generate caption from part_index 1 to ....') + parser.add_argument('--part_nums', default=1000, type=int, help='used to split train_file_id into different parts, and generate caption from part_index 1 to ....') + args = parser.parse_args() + + main(args) + + success_file=f"part_{args.part_index}_success" + os.system(f"touch {success_file}") + + """ + python3 -m torch.distributed.launch --nproc_per_node 1 --master_port 29504 ofa_ddp.py \ + --train_file "/apdcephfs_cq3/share_1311970/A_Youtube/coco_vat_890w_id_title_folderidx_merge.json" \ + --num_workers 8 --batch_size 1 \ + --part_index 11 \ + --part_nums 10000 + + """ + + + diff --git a/modality_generation_codes/thermal_ddp.py b/modality_generation_codes/thermal_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e8280340f3ec9d2e1e959e5bc0abdeb4242cec --- /dev/null +++ b/modality_generation_codes/thermal_ddp.py @@ -0,0 +1,509 @@ + +from __future__ import print_function +from PIL import Image +from torchvision import transforms +# from transformers import OFATokenizer, OFAModel +# from transformers.models.ofa.generate import sequence_generator # from generate import sequence_generator +import os.path +from argparse import ArgumentParser +from torch.utils import data +import json +import torch +import torch.distributed as dist +import os +import os.path as osp +from os.path import join as opj +import pandas as pd +from random import randint +import cv2 +import torch +from torch.utils.data import Dataset, DataLoader +import decord +import glob +import subprocess +import time +import numpy as np + +from utils import get_config, get_data_loader_folder, pytorch03_to_pytorch04, load_inception +from trainer import MUNIT_Trainer, UNIT_Trainer +from torch import nn +from scipy.stats import entropy +import torch.nn.functional as F +import argparse +from torch.autograd import Variable +from data import ImageFolder +import numpy as np +import torchvision.utils as vutils +try: + from itertools import izip as zip +except ImportError: # will be 3.x series + pass +import sys +import torch +import os +import os + +os.environ["HF_DATASETS_OFFLINE"] = "1" +import io + +import cv2 +import numpy as np +from decord import VideoReader, cpu + +import decord +from decord import cpu +import torch +import numpy as np +from PIL import Image +import requests +try: + from petrel_client.client import Client + petrel_backend_imported = True +except (ImportError, ModuleNotFoundError): + petrel_backend_imported = False + + +def get_video_loader(use_petrel_backend: bool = True, + enable_mc: bool = True, + conf_path: str = None): + if petrel_backend_imported and use_petrel_backend: + _client = Client(conf_path=conf_path, enable_mc=enable_mc) + else: + _client = None + + def _loader(video_path): + if _client is not None and 's3:' in video_path: + video_path = io.BytesIO(_client.get(video_path)) + + vr = VideoReader(video_path, num_threads=1, ctx=cpu(0)) + return vr + + return _loader + + +class my_dataset(Dataset): + def __init__(self, args): + super().__init__() + self.args = args + self.shuffle = True + self.resolution = args.resolution # 对于动态大小视频无用 + self.video_loader = get_video_loader() + if args.train_file.endswith('.csv'): + self.train_file = pd.read_csv(args.train_file) + elif args.train_file.endswith('.json'): + # coco_vat_vat0_11_all_id_rootfolder_clsidx_spacy.json + # 格式: id : { 'idx_list' : [0], 'root_folder' : 'coco_vat_9' } + + if hasattr(args, 'part_nums') and args.part_nums > 1: + self.part_nums = args.part_nums + else: + self.part_nums = 100000 + self.part_index = args.part_index + t1 = time.time() + with open(args.train_file, 'r', encoding='utf-8') as f: + self.train_file = json.load(f) + if type(self.train_file) is str: + self.train_file = json.loads(self.train_file) + + self.id_list = list(self.train_file.keys()) + # ============================= + # obtain subset of self.id_list + self.id_list = self.id_list[self.part_nums * (self.part_index - 1):self.part_nums * self.part_index] + + print( + f'Nums of train_file is {len(self.id_list)},part_index:{self.part_index}, first:{self.id_list[0]}') + self.no_caption_id_list = [] + for idx, id in enumerate(self.id_list): + caption_json = osp.join('/apdcephfs_cq3/share_1311970/A_Youtube', + self.train_file[id]['root_folder'], f'{id}_thermal_folder') + mp4_path = osp.join('/apdcephfs_cq3/share_1311970/A_Youtube', self.train_file[id]['root_folder'], + f'{id}.mp4') + if not os.path.exists(caption_json) and os.path.exists(mp4_path): + self.no_caption_id_list.append(mp4_path) + # else: + # print(f'{caption_json} is exist!') + if idx % 10000 == 0: + print(f'Time_cost:{time.time() - t1}s, idx:{idx}, caption_json:{caption_json}') + try: + print(f'Nums of no_thermal_folder_id_list is {len(self.no_caption_id_list)}, first:{self.no_caption_id_list[0]}') + except: + print(f'Nums of no_thermal_folder_id_list is {len(self.no_caption_id_list)}') + + t2 = time.time() + print(f'Time cost:{t2 - t1}s') + # DPT + # self.patch_resize_transform = DPTImageProcessor.from_pretrained("Intel/dpt-large", cache_dir= args.weights_folder) + # glpn + self.patch_resize_transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5)), + transforms.Resize((400, 640)) + ]) + print('Dataset nums is {}'.format(self.__len__())) + time.sleep(10) + + def __len__(self): + return len(self.no_caption_id_list) + + def random_sample(self): + return self.__getitem__(randint(0, self.__len__() - 1)) + + def sequential_sample(self, idx): + if idx >= self.__len__() - 1: + return self.__getitem__(0) + return self.__getitem__(idx + 1) + + def skip_sample(self, idx): + if self.shuffle: + return self.random_sample() + return self.sequential_sample(idx=idx) + + def get_frames_from_video_opencv(self, batchsize=1, video_path=None, caption_nums_per_video=8): + # 加载视频 + video_path = video_path + cap = cv2.VideoCapture(video_path) + + # 确定要提取的帧数 + num_frames = caption_nums_per_video + # 计算每隔多少帧提取一次 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + step = total_frames // num_frames + + # 用于存储提取的图像的tensor + # frames = torch.empty(num_frames, 3, frame_height, frame_width) + frames = [] + height = width = 0 + # frame_idx_list=[] 把视频的帧序号存储下来 + # 直接读取指定帧 + for i in range(num_frames): + # 计算要提取的帧的索引 + idx = i * step + # frame_idx_list.append(idx) + # 设置当前帧为所需的帧 + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + # 读取该帧 + ret, frame = cap.read() + if i == 0: + height, width = frame.shape[:2] + if height < width: + new_height = 256 + new_width = int(width * (new_height / height)) + else: + new_width = 256 + new_height = int(height * (new_width / width)) + if not ret: + break + # 转换为PIL Image并进行缩放 + Image_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + frame = self.patch_resize_transform(Image_frame).unsqueeze(0).unsqueeze(0) + # 将numpy数组转换为tensor并存储在frames中 + # frames[i] = frame + # 将numpy数组转换为tensor并存储在frames中 + frames.append(frame) + + # vr.close() + frames = torch.cat(frames, 1).squeeze(0) + return frames, new_height, new_width + + + def get_frames_from_video_decord(self, batchsize=1, video_path=None, caption_nums_per_video=8): + # 加载视频 + video_path = video_path + vr = self.video_loader(video_path) + frame_width, frame_height = vr[0].shape[1], vr[0].shape[0] + + # 确定要提取的帧数 + num_frames = caption_nums_per_video + # 计算每隔多少帧提取一次 + total_frames = len(vr) + step = total_frames // num_frames + + # 用于存储提取的图像的tensor + frames = [] + height = width = 0 + + # 直接读取指定帧 + for i in range(num_frames): + # 计算要提取的帧的索引 + idx = i * step + # 读取该帧 + frame = vr[idx].asnumpy() + if i == 0: + height, width = frame.shape[:2] + if height < width: + new_height = 256 + new_width = int(width * (new_height / height)) + else: + new_width = 256 + new_height = int(height * (new_width / width)) + + # 转换为PIL Image并进行缩放 + Image_frame = Image.fromarray(frame) + frame = self.patch_resize_transform(Image_frame).unsqueeze(0).unsqueeze(0) + + # 将numpy数组转换为tensor并存储在frames中 + frames.append(frame) + + # vr.close() + frames = torch.cat(frames, 1).squeeze(0) + return frames, new_height, new_width + + def __getitem__(self, idx): + + try: + + # video_id = self.filter_train_file[idx] + # video_path = opj(self.vat_root, video_id) + video_path = self.no_caption_id_list[idx] + video_id = video_path.split('/')[-1].split('.')[0] + # 假如多个程序一起跑,其他已经生成了,就跳过 + caption_video_json = video_path.replace('.mp4', '_thermal_folder') + if os.path.exists(caption_video_json): + print('parallel task has process it :{}'.format(caption_video_json)) + # return '===========', None, torch.random(8,3,self.resolution, self.resolution) + return self.skip_sample(idx) + if not osp.exists(video_path): + print('video {} is not exists and skip this idx! '.format(video_path)) + return self.skip_sample(idx) + # video_frames, height, width = self.get_frames_from_video_opencv( video_path = video_path, caption_nums_per_video = args.caption_nums_per_video) + video_frames, height, width = self.get_frames_from_video_opencv(video_path=video_path, + caption_nums_per_video=args.caption_nums_per_video) + return video_id, video_path, video_frames, height, width + + except Exception as e: + print('Read video error in {},{} and we have skip this !, this will not cause error!'.format(idx, e)) + return self.skip_sample(idx) + + +def synchronize(): + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + + +def thermal_estimation(args): + ######################################## model start ############################# + """https://huggingface.co/docs/transformers/main/en/model_doc/dpt""" + + + config = 'configs/tir2rgb_folder.yaml' + a2b = 0 + checkpoint = './translation_weights.pt' + output_path = '.' + num_style = 1 + config = get_config(config) + config['vgg_model_path'] = output_path + style_dim = config['gen']['style_dim'] + + trainer = MUNIT_Trainer(config) + + state_dict = torch.load(checkpoint) + trainer.gen_a.load_state_dict(state_dict['a']) + trainer.gen_b.load_state_dict(state_dict['b']) + + trainer.cuda(args.local_rank) + trainer.train() + + if args.rank == 0: + print('模型初始化完成') + trainer = torch.nn.parallel.DistributedDataParallel(trainer, device_ids=[args.local_rank], + output_device=args.local_rank) + encode = trainer.module.gen_a.encode if a2b else trainer.module.gen_b.encode # encode function + decode = trainer.module.gen_b.decode if a2b else trainer.module.gen_a.decode # decode function + + + + if args.rank == 0: + print('DDP model') + ######################################## model over ############################# + + ######################################## dataset start ############################# + if args.rank == 0: + print('dataset 初始化') + + train_dataset = my_dataset(args) + if args.rank == 0: + print('dataset_len: ', train_dataset.__len__()) + print('loading dataset is complete!') + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + num_replicas=args.world_size, + rank=args.rank + ) + if args.rank == 0: + print('正在同步') + synchronize() + if args.rank == 0: + print('dataloader 初始化') + dataloader = torch.utils.data.DataLoader(dataset=train_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + sampler=train_sampler, + drop_last=True + ) + ######################################## dataset over ############################# + ######################################## thermal_estimation start ############################# + + for index, (video_ids, video_paths, videos_frames_, h_list, w_list) in enumerate(dataloader): + # print(videos_frames_.shape) + bs, cap_nums, c, h, w = videos_frames_.shape + videos_frames = videos_frames_.reshape(-1, c, h, w) + images = Variable(videos_frames.cuda(args.local_rank), volatile=True) + torch.cuda.empty_cache() + try: + with torch.no_grad(): + predicted_thermal = [] + for i in range(bs*cap_nums): + # print('images[i]', images[i].unsqueeze(0).shape) + content, _ = encode(images[i].unsqueeze(0)) + # print('content', content.shape) + # style = style_fixed if opts.synchronized else Variable(torch.randn(num_style, style_dim, 1, 1).cuda(), volatile=False) + style = Variable(torch.randn(num_style, style_dim, 1, 1).cuda(args.local_rank), volatile=False) + s = style[0].unsqueeze(0) + # print('s', s.shape) + outputs = decode(content, s) + outputs = (outputs + 1) / 2. + # print('outputs', outputs.shape) + + + predicted_thermal.append(outputs) + predicted_thermal = torch.cat(predicted_thermal, dim=0) # (bs*cap_nums, 3, h, w) + predicted_thermal = predicted_thermal.view(bs, cap_nums, c, h, w) + # print(f'predicted_thermal.shape:{predicted_thermal.shape}') + # interpolate to original size + for bs_idx, sample in enumerate(predicted_thermal): + # import ipdb + # ipdb.set_trace() + pic_folder = video_paths[bs_idx].replace('.mp4', '_thermal_folder') + os.makedirs(pic_folder, exist_ok=True) + for frame_idx, frame in enumerate(predicted_thermal[bs_idx]): + # print(frame.shape, h_list, w_list) + prediction = torch.nn.functional.interpolate( + frame.unsqueeze(0), # torch.Size([1, 3, 400, 640]) + size=(h_list[bs_idx], w_list[bs_idx]), + mode="bicubic", + align_corners=False, + ) # torch.Size([1, 1, h=480, w=640]) + # print('prediction.shape:{prediction.shape}') + + vutils.save_image(prediction.data, f"{pic_folder}/{frame_idx}.jpg", padding=0, normalize=True) + + + print(f'{pic_folder} is succeed!') + # sys.exit(0) + del videos_frames, outputs, predicted_thermal + except Exception as e: + print(f'Error:{e}!') + del videos_frames + ######################################## thermal_estimation over ############################# + + +def init_distributed_mode(args): + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.local_rank = int(os.environ['LOCAL_RANK']) + torch.cuda.set_device(args.local_rank) + + args.dist_backend = 'nccl' + args.dist_url = 'env://' + + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank, + timeout=datetime.timedelta(seconds=5400)) + torch.distributed.barrier() + + +import misc + + +def main(args): + misc.init_distributed_mode(args) + if args.rank == 0: + print('进程组初始化完成') + print("started") + print("started caption_json count!") + # glob1(json_path=args.exist_caption_id_list_json) # 'coco_vat_exist_caption_id_list_03141026.json' + ###########################################################3 + import time + t1 = time.time() + thermal_estimation(args) + t2 = time.time() + if args.rank == 0: + print('Time : ', t2 - t1, ' s') + dist.destroy_process_group() # 销毁进程组 + + +def test_dataset(args): + train_dataset = my_dataset(args) + loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers) + from time import time + for i, sample in enumerate(loader): + video_ids, video_paths, videos_frames, h, w = sample + print(i, video_ids, video_paths, videos_frames.shape, h, w) + + +if __name__ == "__main__": + import time + # time.sleep(10000) + import datetime + + # 获取当前时间 + now = datetime.datetime.now() + # 获取当前月份 + month = now.month + # 获取当前日期 + day = now.day + # 获取当前小时 + hour = now.hour + + parser = ArgumentParser() + parser.add_argument('--caption_nums_per_video', type=int, default=8, help='process rank') + parser.add_argument('--batch_size', type=int, default=2) + parser.add_argument('--train_file', type=str, required=True) + parser.add_argument('--vat_root', type=str, default=None) + parser.add_argument('--num_workers', type=int, default=1) + parser.add_argument('--resolution', type=int, default=480) + # parser.add_argument('--exist_caption_id_list_json', type=str, default=f'coco_vat_exist_caption_id_list_{month}{day}{hour}.json',help='') + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_on_itp', action='store_true') # --dist_on_itp ddp + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--distributed', default=True, help='url used to set up distributed training') + parser.add_argument('--gpus', default=[0, 1, 2, 3], help='DP CUDA devices') + parser.add_argument('--part_index', default=1, type=int, + help='used to split train_file_id into different parts, and generate caption from part_index 1 to ....') + parser.add_argument('--part_nums', default=1000, type=int, + help='used to split train_file_id into different parts, and generate caption from part_index 1 to ....') + parser.add_argument('--weights_folder', type=str, default='/apdcephfs_cq3/share_1311970/A_thermal/') + args = parser.parse_args() + # test_dataset(args) + # import ipdb + # ipdb.set_trace() + main(args) + synchronize() + # success_file=f"part_{args.part_index}_success" + success_file = f"/apdcephfs_cq3/share_1311970/A_thermal/part_{args.part_index}_success" + os.system(f"touch {success_file}") + + """ + cd /apdcephfs_cq3/share_1311970/A_thermal/sRGB-TIR + source /apdcephfs_cq3/share_1311970/lb/miniconda3/etc/profile.d/conda.sh + conda activate /apdcephfs_cq3/share_1311970/lb/miniconda3/envs/pytorch1.12.1 + export CUDA_VISIBLE_DEVICES=1 + HF_DATASETS_OFFLINE=1 python3 -m torch.distributed.launch --nproc_per_node 1 --master_port 29504 thermal_ddp.py \ + --train_file /apdcephfs_cq3/share_1311970/A_Youtube/coco_vat_vat0_11_all_id_rootfolder_clsidx_spacy.json \ + --num_workers 8 --batch_size 1 \ + --part_index 91 \ + --part_nums 10000 \ + --resolution 0 \ + --caption_nums_per_video 8 + + """ diff --git a/model/base_model.py b/model/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9cbe6c1d7aef31b4c3991d1d77da5a2bd55feee9 --- /dev/null +++ b/model/base_model.py @@ -0,0 +1,357 @@ +import sys + +import torch +from einops import rearrange +from typing import Optional, Tuple, Union + +from torch import nn +from transformers import CLIPModel as HFCLIPModel, CLIPVisionConfig +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.clip.modeling_clip import CLIP_VISION_INPUTS_DOCSTRING +from transformers.utils import replace_return_docstrings, add_start_docstrings_to_model_forward + + +# class VT_CLIP(nn.Module): +# output_dict: torch.jit.Final[bool] +# +# def __init__( +# self, +# embed_dim: int, +# vision_cfg: CLIPVisionCfg, +# text_cfg: CLIPTextCfg, +# quick_gelu: bool = False, +# cast_dtype: Optional[torch.dtype] = None, +# output_dict: bool = False, +# ): +# super().__init__() +# self.output_dict = output_dict +# self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) +# +# text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) +# self.transformer = text.transformer +# self.context_length = text.context_length +# self.vocab_size = text.vocab_size +# self.token_embedding = text.token_embedding +# self.positional_embedding = text.positional_embedding +# self.ln_final = text.ln_final +# self.text_projection = text.text_projection +# self.register_buffer('attn_mask', text.attn_mask, persistent=False) +# +# self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) +# +# +# +# def unlock_time_attn(self): +# for name, param in self.named_parameters(): +# if 'time' in name: +# param.requires_grad = True +# +# def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): +# # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 +# self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) +# +# def lock_text_tower(self, unlocked_layers=0, freeze_layer_norm=False): +# for param in self.transformer.parameters(): +# param.requires_grad = False +# for param in self.token_embedding.parameters(): +# param.requires_grad = False +# for param in self.ln_final.parameters(): +# param.requires_grad = False +# self.positional_embedding.requires_grad = False +# self.text_projection.requires_grad = False +# +# if unlocked_layers != 0: +# groups = [ +# [ +# self.token_embedding, +# self.positional_embedding, +# ], +# *self.transformer.resblocks[:-1], +# [ +# self.transformer.resblocks[-1], +# self.ln_final, +# ], +# self.text_projection, +# ] +# +# def _unlock(x): +# if isinstance(x, Sequence): +# for g in x: +# _unlock(g) +# else: +# if isinstance(x, torch.nn.Parameter): +# x.requires_grad = True +# else: +# for p in x.parameters(): +# p.requires_grad = True +# +# _unlock(groups[-unlocked_layers:]) +# +# @torch.jit.ignore +# def set_grad_checkpointing(self, enable=True): +# self.visual.set_grad_checkpointing(enable) +# self.transformer.grad_checkpointing = enable +# +# def encode_image(self, image, normalize: bool = False): +# features = self.visual(image) +# return F.normalize(features, dim=-1) if normalize else features +# +# def encode_text(self, text, normalize: bool = False): +# cast_dtype = self.transformer.get_cast_dtype() +# +# x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] +# +# x = x + self.positional_embedding.to(cast_dtype) +# x = x.permute(1, 0, 2) # NLD -> LND +# x = self.transformer(x, attn_mask=self.attn_mask) +# x = x.permute(1, 0, 2) # LND -> NLD +# x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] +# # take features from the eot embedding (eot_token is the highest number in each sequence) +# x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection +# return F.normalize(x, dim=-1) if normalize else x +# +# def forward( +# self, +# image: Optional[torch.Tensor] = None, +# text: Optional[torch.Tensor] = None, +# ): +# image_features = self.encode_image(image, normalize=True) if image is not None else None +# text_features = self.encode_text(text, normalize=True) if text is not None else None +# if self.output_dict: +# return { +# "image_features": image_features, +# "text_features": text_features, +# "logit_scale": self.logit_scale.exp() +# } +# return image_features, text_features, self.logit_scale.exp() +from model.process_clip import get_global_value, set_global_value + + +def SET_GLOBAL_VALUE(k, v): + set_global_value(k, v) + +class CLIPVisionEmbeddings(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + # (b t) c h w + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) # b hw c + return embeddings + +class CLIPVisionEmbeddings3D(nn.Module): + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.num_frames = config.num_frames + self.tube_size = config.tube_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) + + def expand3d(self): + + state_dict = self.patch_embedding.state_dict() + state_dict_expand = state_dict['weight'].unsqueeze(2) + device, dtype = state_dict_expand.device, state_dict_expand.dtype + # print(device, dtype) + + zero = torch.zeros_like(state_dict_expand).to(device=device, dtype=dtype) + state_dict_expand3d = torch.cat([state_dict_expand] + (self.tube_size-1)*[zero], dim=2) + + # state_dict_expand3d = torch.cat([state_dict_expand / self.tube_size] * self.tube_size, dim=2) + + patch_embedding = nn.Conv3d( + in_channels=self.patch_embedding.in_channels, + out_channels=self.embed_dim, + kernel_size=(self.tube_size, self.patch_size, self.patch_size), + stride=(self.tube_size, self.patch_size, self.patch_size), + bias=False, + ).to(device=device, dtype=dtype) + patch_embedding.load_state_dict({'weight': state_dict_expand3d}) + self.patch_embedding = patch_embedding + + + class_embedding = nn.Parameter(self.class_embedding.data.repeat(self.num_frames // self.tube_size, 1)).to(device=device, dtype=dtype) + self.class_embedding = class_embedding + + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + # (b t) c h w + batch_size = pixel_values.shape[0] // self.num_frames + pixel_values = rearrange(pixel_values, '(b t) c h w -> b c t h w', b=batch_size, t=self.num_frames) + # print('pixel_values', pixel_values.shape) + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, t, grid, grid] + # print('patch_embeds', patch_embeds.shape) + # SET_GLOBAL_VALUE('NUM_FRAMES', patch_embeds.shape[2]) + patch_embeds = rearrange(patch_embeds, 'b c t h w -> b t (h w) c') + + class_embeds = self.class_embedding.unsqueeze(1).unsqueeze(0).repeat(batch_size, 1, 1, 1) # b t 1 c + # print('class_embeds', class_embeds.device, class_embeds.dtype) + # print('patch_embeds', patch_embeds.device, patch_embeds.dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=2) # b t hw+1 c + embeddings = embeddings + self.position_embedding(self.position_ids) + embeddings = rearrange(embeddings, 'b t hw_1 c -> (b t) hw_1 c') + return embeddings + +class CLIPModel(HFCLIPModel): + def __init__(self, config, num_frames, add_time_attn, vl_new, tube_size): + super(CLIPModel, self).__init__(config) + config.vision_config.num_frames = num_frames + config.vision_config.tube_size = tube_size + if add_time_attn: + if vl_new: + self.vision_model.embeddings = CLIPVisionEmbeddings3D(config.vision_config) + else: + self.vision_model.embeddings = CLIPVisionEmbeddings(config.vision_config) + self.T = config.vision_config.num_frames // config.vision_config.tube_size + self.vision_model.forward = self.vision_model_forward + + @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) + def vision_model_forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.vision_model.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.vision_model.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.vision_model.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + if len(pixel_values.shape) == 7: + b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape + # print(pixel_values.shape) + B = b_new * pair_new * bs_new + pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new) + + elif len(pixel_values.shape) == 5: + B, _, T, _, _ = pixel_values.shape + # print(pixel_values.shape) + pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w') + else: + # print(pixel_values.shape) + B, _, _, _ = pixel_values.shape + T = 1 + hidden_states = self.vision_model.embeddings(pixel_values) + # print('hidden_states', hidden_states.shape) + # + # if self.temporal_embedding is not None and get_global_value()['NUM_FRAMES'] != 1: + # n = hidden_states.shape[1] + # hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=T) + # hidden_states = hidden_states + self.temporal_embedding[:, :T, :] + # hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + T = self.T + # print('B.shape, T.shape', B.shape, T.shape) + hidden_states = self.vision_model.patch_dropout(hidden_states, B, T) + # print('patch_dropout', hidden_states.shape) + hidden_states = self.vision_model.pre_layrnorm(hidden_states) + + encoder_outputs = self.vision_model.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.vision_model.post_layernorm(pooled_output) + + pooled_output = pooled_output.reshape(B, T, -1).mean(1) + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def encode_image(self, image, normalize: bool = False): + vision_outputs = self.vision_model( + pixel_values=image, + return_dict=True, + ) + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + return image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) if normalize else image_embeds + + def encode_text(self, input_ids, attention_mask, normalize: bool = False): + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + ) + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + return text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) if normalize else text_embeds + + + def forward( + self, + image=None, + input_ids=None, attention_mask=None + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(input_ids, attention_mask, normalize=True) if input_ids is not None else None + # if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + # return image_features, text_features, self.logit_scale.exp() + + diff --git a/model/build_model.py b/model/build_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7551fdb4992452fa0669da39e149788539023781 --- /dev/null +++ b/model/build_model.py @@ -0,0 +1,197 @@ +import logging +import argparse +import os.path +import numpy as np +import torch +from torch import nn +from transformers import AutoConfig, CLIPPreTrainedModel + + +from model.base_model import CLIPModel +from model.process_clip import add_time_attn_block, convert_model_to_lora, set_global_value, resize_pos +from open_clip import convert_weights_to_lp +from open_clip.transformer import PatchDropout +from training.distributed import is_master + + +def SET_GLOBAL_VALUE(k, v): + set_global_value(k, v) + +def create_vat_model(args): + + config = AutoConfig.from_pretrained(args.model, cache_dir=args.cache_dir) + model = CLIPModel(config, args.num_frames, args.add_time_attn, args.clip_type=='vl_new', args.tube_size) + + model.vision_model.patch_dropout = PatchDropout(args.force_patch_dropout) + + device = args.device + precision = args.precision + if precision in ("fp16", "bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model.to(device=device) + convert_weights_to_lp(model, dtype=dtype) + elif precision in ("pure_fp16", "pure_bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model.to(device=device, dtype=dtype) + else: + model.to(device=device) + + if args.pretrained: + try: + args.pretrained = os.path.join(args.cache_dir, args.pretrained) + if is_master(args): + logging.info(f'Loading pretrained {args.model} weights ({args.pretrained}).') + # incompatible_keys = load_checkpoint(model, pretrained, strict=False) + ckpt = torch.load(args.pretrained, map_location='cpu') + incompatible_keys = model.load_state_dict(ckpt, strict=False if args.add_time_attn else True) + if is_master(args): + logging.info(incompatible_keys) + except Exception as e: + if is_master(args): + logging.info(f"Failed loading pretrained model with {e}") + else: + if is_master(args): + logging.info(f"No pretrained model to load in \'{args.pretrained}\'") + + if args.add_time_attn: + add_time_attn_block(model.vision_model.encoder, device=device) + if is_master(args): + logging.info(f'Convert spatial attention to time attention pretrained.') + + if args.clip_type == 'al': + resize_pos(model.vision_model.embeddings, args) + if is_master(args): + logging.info(f'Resize to position embedding successfully.') + + if args.clip_type == 'vl_new': + model.vision_model.embeddings.expand3d() + + if args.init_temp != 0: + with torch.no_grad(): + model.logit_scale.fill_(np.log(1 / float(args.init_temp))) + if is_master(args): + logging.info(f'Reset logit scale to {args.init_temp} (log-scale) and trainable {args.learn_temp}.') + + if args.convert_to_lora: + convert_model_to_lora(args, model) + if is_master(args): + logging.info(f"Successfuly convert model to lora style.") + + # if output_dict and hasattr(model, "output_dict"): + # model.output_dict = True + + return model + + +if __name__ == '__main__': + MODEL_DICT = {"ViT-L-14": "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K", + "ViT-H-14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"} + CHECKPOINT_DICT = {"ViT-L-14": "models--laion--CLIP-ViT-L-14-DataComp.XL-s13B-b90K/snapshots/84c9828e63dc9a9351d1fe637c346d4c1c4db341/pytorch_model.bin", + "ViT-H-14": "models--laion--CLIP-ViT-H-14-laion2B-s32B-b79K/snapshots/94a64189c3535c1cb44acfcccd7b0908c1c8eb23/pytorch_model.bin"} + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.pretrained = True + args.model = MODEL_DICT["ViT-L-14"] + args.pretrained = CHECKPOINT_DICT["ViT-L-14"] + args.cache_dir = 'D:\Omni-modal-valdt-1kw' + args.device = 'cpu' + args.precision = None + args.lock_text = True + args.lock_image = True + args.init_temp = 0 + args.force_patch_dropout = 0.5 + args.add_time_attn = True + args.convert_to_lora = True + args.lora_r = 16 + args.lora_alpha = 16 + args.lora_dropout = 0.0 # 0.1? + args.num_frames = 8 + args.tube_size = 1 + args.clip_type = 'vl_new' + args.num_mel_bins = 128 + args.target_length = 1024 + args.audio_sample_rate = 16000 + args.audio_mean = 1 + args.audio_std = 1 + args.rank = 0 + + # SET_GLOBAL_VALUE('PATCH_DROPOUT', args.force_patch_dropout) + # SET_GLOBAL_VALUE('NUM_FRAMES', args.num_frames) + + model = create_vat_model(args) + + + '''方法1,自定义函数 参考自 https://blog.csdn.net/qq_33757398/article/details/109210240''' + + + def model_structure(model): + blank = ' ' + print('-' * 150) + print('|' + ' ' * 44 + 'weight name' + ' ' * 45 + '|' \ + + ' ' * 10 + 'weight shape' + ' ' * 10 + '|' \ + + ' ' * 3 + 'number' + ' ' * 3 + '|') + print('-' * 150) + num_para = 0 + type_size = 1 # 如果是浮点数就是4 + + for index, (key, w_variable) in enumerate(model.named_parameters()): + if len(key) <= 100: + key = key + (100 - len(key)) * blank + shape = str(w_variable.shape) + if len(shape) <= 30: + shape = shape + (30 - len(shape)) * blank + each_para = 1 + for k in w_variable.shape: + each_para *= k + num_para += each_para + str_num = str(each_para) + if len(str_num) <= 10: + str_num = str_num + (10 - len(str_num)) * blank + + print('| {} | {} | {} |'.format(key, shape, str_num)) + print('-' * 150) + print('The total number of parameters: ' + str(num_para)) + print('The parameters of Model {}: {:4f}M'.format(model._get_name(), num_para * type_size / 1000 / 1000)) + print('-' * 150) + + + model_structure(model) + # model_structure(model.vision_model) + # model_structure(model.text_model) + + + # model.lock_image_tower(unlocked_groups=1) + # model.lock_text_tower(unlocked_layers=0) + # model.unlock_time_attn() + + if args.lock_image: + # if args.clip_type == 'al' or args.clip_type == 'dl': + # for param in model.vision_model.embeddings.parameters(): + # param.requires_grad = True + # for param in model.vision_model.pre_layrnorm.parameters(): + # param.requires_grad = True + # else: + for param in model.vision_model.embeddings.parameters(): + param.requires_grad = False + for param in model.vision_model.pre_layrnorm.parameters(): + param.requires_grad = False + for param in model.vision_model.embeddings.position_embedding.parameters(): + param.requires_grad = False + model.vision_model.embeddings.class_embedding.requires_grad = True + + + if args.lock_text: + for param in model.text_model.parameters(): + param.requires_grad = False + for param in model.text_projection.parameters(): + param.requires_grad = False + + + for n, p in model.named_parameters(): + # if p.requires_grad: + print(n, '--->', p.requires_grad) + b, c, t, h, w = 2, 3, args.num_frames, 224, 224 + x = torch.randn(b, c, t, h, w) + y = model(image=x) + print() \ No newline at end of file diff --git a/model/languagebind.py b/model/languagebind.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4091fef20bd6623634e6693104e18fab33222e --- /dev/null +++ b/model/languagebind.py @@ -0,0 +1,140 @@ + +import gradio as gr +import argparse +import numpy as np +import torch +from torch import nn + +from data.process_image import load_and_transform_image, get_image_transform +from main import SET_GLOBAL_VALUE +from model.build_model import create_vat_model +from data.process_audio import load_and_transform_audio, get_audio_transform +from data.process_video import load_and_transform_video, get_video_transform +from data.process_depth import load_and_transform_depth, get_depth_transform +from data.process_thermal import load_and_transform_thermal, get_thermal_transform +from data.process_text import load_and_transform_text +from open_clip import get_tokenizer +from open_clip.factory import HF_HUB_PREFIX + + + + + + +class LanguageBind(nn.Module): + def __init__(self, args, no_temp=False): + super(LanguageBind, self).__init__() + self.no_temp = no_temp + MODEL_DICT = {"ViT-L-14": "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K", + "ViT-H-14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"} + args.pretrained = False + args.model = MODEL_DICT["ViT-L-14"] + args.cache_dir = 'D:/Omni-modal-valdt-audio' + args.video_decode_backend = 'decord' + # args.device = 'cpu' + args.device = 'cuda:0' + device = torch.device(args.device) + args.precision = None + args.init_temp = 0 + args.force_patch_dropout = 0.0 + args.add_time_attn = False + args.convert_to_lora = True + args.lora_r = 2 + args.lora_alpha = 16 + args.lora_dropout = 0.0 # 0.1? + args.num_frames = 8 + args.clip_type = 'vl' + args.num_mel_bins = 1008 + args.target_length = 112 + args.audio_sample_rate = 16000 + args.audio_mean = 4.5689974 + args.audio_std = -4.2677393 + args.max_depth = 10 + args.image_size = 224 + args.rank = 0 + SET_GLOBAL_VALUE('PATCH_DROPOUT', args.force_patch_dropout) + SET_GLOBAL_VALUE('NUM_FRAMES', args.num_frames) + args.clip_type = ['il', 'vl', 'al', 'dl', 'tl'] + + + temp_clip_type = args.clip_type + self.modality_encoder = {} + self.modality_proj = {} + self.modality_scale = {} + for c in temp_clip_type: + args.clip_type = c + if c == 'il': + args.convert_to_lora = False + model = create_vat_model(args) + args.convert_to_lora = True + elif c == 'vl': + args.lora_r = 64 + args.add_time_attn = True + model = create_vat_model(args) + args.add_time_attn = False + args.lora_r = 2 + elif c == 'al': + args.lora_r = 8 + model = create_vat_model(args) + args.lora_r = 2 + else: + model = create_vat_model(args) + ''' + state_dict = torch.load(f'model_zoo/{c}.pt', map_location='cpu') + if state_dict.get('state_dict', None) is not None: + state_dict = state_dict['state_dict'] + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + print(f'load {c}, {msg}') + ''' + if c == 'vl': + self.modality_encoder['video'] = model.vision_model + self.modality_proj['video'] = model.visual_projection + self.modality_scale['video'] = model.logit_scale + elif c == 'al': + self.modality_encoder['audio'] = model.vision_model + self.modality_proj['audio'] = model.visual_projection + self.modality_scale['audio'] = model.logit_scale + elif c == 'dl': + self.modality_encoder['depth'] = model.vision_model + self.modality_proj['depth'] = model.visual_projection + self.modality_scale['depth'] = model.logit_scale + elif c == 'tl': + self.modality_encoder['thermal'] = model.vision_model + self.modality_proj['thermal'] = model.visual_projection + self.modality_scale['thermal'] = model.logit_scale + elif c == 'il': + self.modality_encoder['image'] = model.vision_model + self.modality_proj['image'] = model.visual_projection + self.modality_scale['image'] = model.logit_scale + else: + raise NameError(f'No clip_type of {c}') + self.modality_encoder['language'] = model.text_model + self.modality_proj['language'] = model.text_projection + + self.modality_encoder = nn.ModuleDict(self.modality_encoder) + self.modality_proj = nn.ModuleDict(self.modality_proj) + + def forward(self, inputs): + outputs = {} + for key, value in inputs.items(): + value = self.modality_encoder[key](**value)[1] + value = self.modality_proj[key](value) + value = value / value.norm(p=2, dim=-1, keepdim=True) + if not self.no_temp: + if key != 'language': + value = value * self.modality_scale[key].exp() + outputs[key] = value + return outputs + + + +def stack_dict(x, device): + if len(x) == 0: + return None + out_dict = {} + keys = list(x[0].keys()) + for key in keys: + out_dict[key] = torch.stack([i[key] for i in x]).to(device) + return out_dict \ No newline at end of file diff --git a/model/process_clip.py b/model/process_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..931e4bccc66ade2a6dae4e625ed27dbcc0d5b5ff --- /dev/null +++ b/model/process_clip.py @@ -0,0 +1,720 @@ +import logging +import math +from typing import Optional, Tuple +from einops import rearrange +from peft import LoraConfig, get_peft_model +from transformers import CLIPConfig +from transformers.models.clip.modeling_clip import CLIPEncoderLayer as SpatialCLIPEncoderLayer, CLIPAttention, CLIPMLP +import torch +from torch import nn +from torch.nn import functional as F + +from training.distributed import is_master + +aaa = {'NUM_FRAMES': 1, 'PATCH_DROPOUT': 0.0} + +def set_global_value(k, v): + global aaa + aaa[k] = v + +def get_global_value(): + global aaa + return aaa + +# @dataclass +# class CLIPVisionCfg: +# layers: Union[Tuple[int, int, int, int], int] = 12 +# width: int = 768 +# head_width: int = 64 +# mlp_ratio: float = 4.0 +# patch_size: int = 16 +# image_size: Union[Tuple[int, int], int] = 224 +# cast_dtype: str = None +# num_frames: int = 2 +# +# ls_init_value: Optional[float] = None # layer scale initial value +# patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results +# input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design +# global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) +# attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer +# n_queries: int = 256 # n_queries for attentional pooler +# attn_pooler_heads: int = 8 # n heads for attentional_pooling +# output_tokens: bool = False +# +# timm_model_name: str = None # a valid model name overrides layers, width, patch_size +# timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model +# timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') +# timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') +# timm_proj_bias: bool = False # enable bias final projection +# timm_drop: float = 0. # head dropout +# timm_drop_path: Optional[float] = None # backbone stochastic depth + +# class Video_VisionTransformer(nn.Module): +# output_tokens: torch.jit.Final[bool] +# +# def __init__( +# self, +# num_frames: int, +# image_size: int, +# patch_size: int, +# width: int, +# layers: int, +# heads: int, +# mlp_ratio: float, +# ls_init_value: float = None, +# global_average_pool: bool = False, +# attentional_pool: bool = False, +# n_queries: int = 256, +# attn_pooler_heads: int = 8, +# output_dim: int = 512, +# patch_dropout: float = 0., +# input_patchnorm: bool = False, +# act_layer: Callable = nn.GELU, +# norm_layer: Callable = LayerNorm, +# output_tokens: bool = False +# ): +# super().__init__() +# self.output_tokens = output_tokens +# image_height, image_width = self.image_size = to_2tuple(image_size) +# patch_height, patch_width = self.patch_size = to_2tuple(patch_size) +# self.grid_size = (image_height // patch_height, image_width // patch_width) +# self.output_dim = output_dim +# +# # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 +# self.input_patchnorm = input_patchnorm +# +# if input_patchnorm: +# patch_input_dim = patch_height * patch_width * 3 +# self.patchnorm_pre_ln = LayerNorm(patch_input_dim) +# self.conv1 = nn.Linear(patch_input_dim, width) +# else: +# self.patchnorm_pre_ln = nn.Identity() +# self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, +# bias=False) +# +# # class embeddings and positional embeddings +# self.scale = scale = width ** -0.5 +# self.class_embedding = nn.Parameter(scale * torch.randn(width)) +# self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) +# +# self.temporal_embedding = nn.Parameter(torch.zeros(1, num_frames, width)) +# # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn +# self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() +# +# self.ln_pre = norm_layer(width) +# self.transformer = Transformer( +# width, +# layers, +# heads, +# mlp_ratio, +# ls_init_value=ls_init_value, +# act_layer=act_layer, +# norm_layer=norm_layer, +# ) +# +# self.global_average_pool = global_average_pool +# if attentional_pool: +# self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) +# self.ln_post = norm_layer(output_dim) +# self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) +# else: +# self.attn_pool = None +# self.ln_post = norm_layer(width) +# self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) +# +# self.init_parameters() +# +# +# def lock(self, unlocked_groups=0, freeze_bn_stats=False): +# for param in self.parameters(): +# param.requires_grad = False +# +# if unlocked_groups != 0: +# groups = [ +# [ +# self.conv1, +# self.positional_embedding, +# self.ln_pre, +# ], +# *zip(self.transformer.resblocks[:-1], [self.class_embedding for i in range(len(self.transformer.resblocks[:-1]))]), +# [ +# self.class_embedding, +# self.transformer.resblocks[-1], +# self.ln_post, +# ], +# [self.proj, self.temporal_embedding] +# ] +# +# def _unlock(x): +# if isinstance(x, Sequence): +# for g in x: +# _unlock(g) +# else: +# if isinstance(x, torch.nn.Parameter): +# x.requires_grad = True +# else: +# for p in x.parameters(): +# p.requires_grad = True +# +# _unlock(groups[-unlocked_groups:]) +# +# def init_parameters(self): +# # FIXME OpenAI CLIP did not define an init for the VisualTransformer +# # TODO experiment if default PyTorch init, below, or alternate init is best. +# +# nn.init.normal_(self.temporal_embedding, std=self.scale) +# # nn.init.normal_(self.class_embedding, std=self.scale) +# # nn.init.normal_(self.positional_embedding, std=self.scale) +# # +# # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) +# # attn_std = self.transformer.width ** -0.5 +# # fc_std = (2 * self.transformer.width) ** -0.5 +# # for block in self.transformer.resblocks: +# # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) +# # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) +# # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) +# # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) +# # +# # if self.text_projection is not None: +# # nn.init.normal_(self.text_projection, std=self.scale) +# # pass +# +# @torch.jit.ignore +# def set_grad_checkpointing(self, enable=True): +# self.transformer.grad_checkpointing = enable +# +# def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +# if self.global_average_pool: +# return x.mean(dim=1), x +# else: +# return x[:, 0], x[:, 1:] +# +# def forward(self, x: torch.Tensor): +# # print('input img', x.shape) +# B, _, T, _, _ = x.shape +# x = rearrange(x, 'b c t h w -> (b t) c h w') +# # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 +# if self.input_patchnorm: +# # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') +# x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], +# self.patch_size[1]) +# x = x.permute(0, 2, 4, 1, 3, 5) +# x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) +# x = self.patchnorm_pre_ln(x) +# x = self.conv1(x) +# else: +# x = self.conv1(x) # shape = [*, width, grid, grid] +# x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] +# x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] +# +# # print('embed img', x.shape) +# # class embeddings and positional embeddings +# x = torch.cat( +# [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), +# x], dim=1) # shape = [*, grid ** 2 + 1, width] +# x = x + self.positional_embedding.to(x.dtype) +# +# n = x.shape[1] +# x = rearrange(x, '(b t) n d -> (b n) t d', t=T) +# x = x + self.temporal_embedding[:, :T, :] +# x = rearrange(x, '(b n) t d -> (b t) n d', n=n) +# +# # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in +# x = self.patch_dropout(x) +# x = self.ln_pre(x) +# +# # print('patch_dropout img', x.shape) +# x = x.permute(1, 0, 2) # NLD -> LND +# # print('permute img', x.shape) +# x = self.transformer(x) +# x = x.permute(1, 0, 2) # LND -> NLD +# +# if self.attn_pool is not None: +# x = self.attn_pool(x) +# x = self.ln_post(x) +# pooled, tokens = self._global_pool(x) +# else: +# pooled, tokens = self._global_pool(x) +# pooled = self.ln_post(pooled) # bt, d +# +# pooled = pooled.reshape(B, T, -1).mean(1) +# if self.proj is not None: +# pooled = pooled @ self.proj +# +# if self.output_tokens: +# return pooled, tokens +# +# return pooled +# +# def _build_vision_tower( +# embed_dim: int, +# vision_cfg: CLIPVisionCfg, +# quick_gelu: bool = False, +# cast_dtype: Optional[torch.dtype] = None +# ): +# if isinstance(vision_cfg, dict): +# vision_cfg = CLIPVisionCfg(**vision_cfg) +# +# # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more +# # memory efficient in recent PyTorch releases (>= 1.10). +# # NOTE: timm models always use native GELU regardless of quick_gelu flag. +# act_layer = QuickGELU if quick_gelu else nn.GELU +# +# vision_heads = vision_cfg.width // vision_cfg.head_width +# norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm +# visual = Video_VisionTransformer( +# num_frames=vision_cfg.num_frames, +# image_size=vision_cfg.image_size, +# patch_size=vision_cfg.patch_size, +# width=vision_cfg.width, +# layers=vision_cfg.layers, +# heads=vision_heads, +# mlp_ratio=vision_cfg.mlp_ratio, +# ls_init_value=vision_cfg.ls_init_value, +# patch_dropout=vision_cfg.patch_dropout, +# input_patchnorm=vision_cfg.input_patchnorm, +# global_average_pool=vision_cfg.global_average_pool, +# attentional_pool=vision_cfg.attentional_pool, +# n_queries=vision_cfg.n_queries, +# attn_pooler_heads=vision_cfg.attn_pooler_heads, +# output_tokens=vision_cfg.output_tokens, +# output_dim=embed_dim, +# act_layer=act_layer, +# norm_layer=norm_layer, +# ) +# +# return visual + + + + +class CLIPEncoderLayer(SpatialCLIPEncoderLayer): + def __init__(self, config: CLIPConfig): + super().__init__(config) + self.T = config.num_frames // config.tube_size + self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames // config.tube_size, config.hidden_size)) + nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5) + + self.embed_dim = config.hidden_size + self.temporal_attn = CLIPAttention(config) + # self.temporal_mlp = CLIPMLP(config) + # self.t_attn_gate = nn.Parameter(torch.tensor([-20.])) + # self.t_ffn_gate = nn.Parameter(torch.tensor([-20.])) + self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + # self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + + # print('input hidden_states', hidden_states.requires_grad) + bt, n, d = hidden_states.shape + t = self.T + + + # time embed + if t != 1: + n = hidden_states.shape[1] + # print(hidden_states.shape, '(b t) n d -> (b n) t d') + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # print(hidden_states.shape) + hidden_states = hidden_states + self.temporal_embedding[:, :t, :] + hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # time attn + residual = hidden_states + hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # hidden_states = self.layer_norm1(hidden_states) # share layernorm + hidden_states = self.temporal_layer_norm1(hidden_states) + + + # print('after t_norm hidden_states', hidden_states.requires_grad) + + hidden_states, attn_weights = self.temporal_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + + # if self.gradient_checkpointing and self.training: + # # print(self.gradient_checkpointing, self.training) + # def create_custom_forward(module): + # def custom_forward(*inputs): + # return module(*inputs, output_attentions) + # + # return custom_forward + # + # hidden_states, attn_weights = torch.utils.checkpoint.checkpoint( + # create_custom_forward(self.temporal_attn), + # hidden_states, + # attention_mask, + # causal_attention_mask, + # ) + # else: + # hidden_states, attn_weights = self.temporal_attn( + # hidden_states=hidden_states, + # attention_mask=attention_mask, + # causal_attention_mask=causal_attention_mask, + # output_attentions=output_attentions, + # ) + + + + # print('after t_attn hidden_states', hidden_states.requires_grad) + + + hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # residual = hidden_states + # hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t) + # # hidden_states = self.layer_norm2(hidden_states) # share layernorm + # hidden_states = self.temporal_layer_norm2(hidden_states) + # hidden_states = self.temporal_mlp(hidden_states) + # hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n) + + # spatial attn + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + + # print('after norm1 hidden_states', hidden_states.requires_grad) + + # if self.gradient_checkpointing and self.training: + # # print(self.gradient_checkpointing, self.training) + # def create_custom_forward(module): + # def custom_forward(*inputs): + # return module(*inputs, output_attentions) + # + # return custom_forward + # + # hidden_states, attn_weights = torch.utils.checkpoint.checkpoint( + # create_custom_forward(self.self_attn), + # hidden_states, + # attention_mask, + # causal_attention_mask, + # ) + # else: + # hidden_states, attn_weights = self.self_attn( + # hidden_states=hidden_states, + # attention_mask=attention_mask, + # causal_attention_mask=causal_attention_mask, + # output_attentions=output_attentions, + # ) + + + + + # print('after self_attn hidden_states', hidden_states.requires_grad) + + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + + # print('after norm2 hidden_states', hidden_states.requires_grad) + + hidden_states = self.mlp(hidden_states) + # if self.gradient_checkpointing and self.training: + # hidden_states = torch.utils.checkpoint.checkpoint(self.mlp, hidden_states) + # else: + # hidden_states = self.mlp(hidden_states) + + + # print('after mlp hidden_states', hidden_states.requires_grad) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + + + +# class ResidualAttentionBlock(SpatialResidualAttentionBlock): +# def __init__(self, +# num_frames: int, +# d_model: int, +# n_head: int, +# mlp_ratio: float = 4.0, +# ls_init_value: float = None, +# act_layer: Callable = nn.GELU, +# norm_layer: Callable = LayerNorm, +# is_cross_attention: bool = False,): +# super().__init__(d_model, n_head, mlp_ratio, ls_init_value, act_layer, norm_layer, is_cross_attention) +# +# self.num_frames = num_frames +# self.time_ln_1 = norm_layer(d_model) +# self.time_attn = nn.MultiheadAttention(d_model, n_head) +# self.time_ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() +# +# def time_attention( +# self, +# q_x: torch.Tensor, +# k_x: Optional[torch.Tensor] = None, +# v_x: Optional[torch.Tensor] = None, +# attn_mask: Optional[torch.Tensor] = None, +# ): +# k_x = k_x if k_x is not None else q_x +# v_x = v_x if v_x is not None else q_x +# +# attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None +# return self.time_attn( +# q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask +# )[0] +# +# def forward( +# self, +# q_x: torch.Tensor, +# k_x: Optional[torch.Tensor] = None, +# v_x: Optional[torch.Tensor] = None, +# attn_mask: Optional[torch.Tensor] = None, +# ): +# k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None +# v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None +# +# n, bt, d = q_x.shape +# t = get_global_value()['NUM_FRAMES'] +# +# # time attn +# # print('q_x', q_x.shape) +# xt = rearrange(q_x, 'n (b t) d -> t (b n) d', t=t) +# # print('xt', xt.shape) +# xt = self.time_ls_1(self.time_attention(q_x=self.time_ln_1(xt), k_x=None, v_x=None, attn_mask=None)) +# # print('time_attention xt', xt.shape) +# q_x = q_x + rearrange(xt, 't (b n) d -> n (b t) d', n=n) +# # print('time_attention q_x', xt.shape) +# +# # spatial attn +# x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) +# +# x = x + self.ls_2(self.mlp(self.ln_2(x))) +# return x + +def print_trainable_parameters(model, msg=''): + """ + Prints the number of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + logging.info(f"{msg} Trainable params: {trainable_params} || all params: {all_param} || " + f"trainable: {100 * trainable_params / all_param:.2f}%") + +def convert_model_to_lora(args, model): + if args.clip_type == 'vl' and args.add_time_attn: + target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj", + "temporal_attn.q_proj", "temporal_attn.out_proj", + "temporal_mlp.fc1", "temporal_mlp.fc2" + ] + else: + target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"] + config = LoraConfig( + r=args.lora_r, # 16 + lora_alpha=args.lora_alpha, # 16 + target_modules=target_modules, # self_attn.out_proj + lora_dropout=args.lora_dropout, # 0.1 + bias="none", + modules_to_save=[], + ) + model.vision_model.encoder.is_gradient_checkpointing = False + model.vision_model.encoder = get_peft_model(model.vision_model.encoder, config) + if is_master(args): + print_trainable_parameters(model.vision_model.encoder, msg='The model.vision_model.encoder: ') + # model.text_model.encoder.is_gradient_checkpointing = False + # model.text_model.encoder = get_peft_model(model.text_model.encoder, config) + # if is_master(args): + # print_trainable_parameters(model.text_model.encoder, msg='The model.text_model.encoder: ') + + + +def add_time_attn_block(m: nn.ModuleList, device): + config = m.config + for i, sub_m in enumerate(m.layers): + if isinstance(sub_m, SpatialCLIPEncoderLayer): + oup = CLIPEncoderLayer(config).to(device) + state_dict = sub_m.state_dict() + + new_state_dict = {} + for k, v in state_dict.items(): + if 'self_attn' in k: + new_state_dict[k] = v + # if 'out_proj' in k: + # v = torch.zeros_like(v, dtype=v.dtype, device=v.device) + new_k = 'temporal_attn.' + '.'.join(k.split('.')[1:]) + new_state_dict[new_k] = v + # elif 'mlp' in k: + # new_state_dict[k] = v + # # if 'out_proj' in k: + # # v = torch.zeros_like(v, dtype=v.dtype, device=v.device) + # new_k = 'temporal_mlp.' + '.'.join(k.split('.')[1:]) + # new_state_dict[new_k] = v + elif 'layer_norm1' in k: + new_state_dict[k] = v + new_k = 'temporal_layer_norm1.' + '.'.join(k.split('.')[1:]) + new_state_dict[new_k] = v + # elif 'layer_norm2' in k: + # new_state_dict[k] = v + # new_k = 'temporal_layer_norm2.' + '.'.join(k.split('.')[1:]) + # new_state_dict[new_k] = v + else: + new_state_dict[k] = v + + missing_keys, unexpected_keys = oup.load_state_dict(new_state_dict, strict=False) + # assert missing_keys == ["t_attn_gate", "t_ffn_gate"] + # print(missing_keys, unexpected_keys) + assert missing_keys == ['temporal_embedding'] + assert unexpected_keys == [] + m.layers[i] = oup + +def resize_pos(m: nn.Module, args): + # convert embedding + if args.clip_type == 'al': + m.image_size = [args.num_mel_bins, args.target_length] + m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size + + # m.config.num_channels = 1 + # new_patch_embedding = nn.Conv2d( + # in_channels=m.config.num_channels, + # out_channels=m.embed_dim, + # kernel_size=m.patch_size, + # stride=m.patch_size, + # bias=False, + # ) + # state_dict = m.patch_embedding.state_dict() + # for k, v in state_dict.items(): + # state_dict[k] = torch.mean(v, dim=1, keepdim=True).to(v.dtype) + # m.patch_embedding = new_patch_embedding + # m.patch_embedding.load_state_dict(state_dict) + + # pos resize + old_pos_embed_state_dict = m.position_embedding.state_dict() + old_pos_embed = old_pos_embed_state_dict['weight'] + dtype = old_pos_embed.dtype + grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size] + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + m.to(args.device) + return + + m.num_patches = grid_size[0] * grid_size[1] + m.num_positions = m.num_patches + 1 + m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1))) + new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim) + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = [int(math.sqrt(len(pos_emb_img)))]*2 + + if is_master(args): + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode='bicubic', + antialias=True, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype) + m.position_embedding = new_position_embedding + m.position_embedding.load_state_dict(old_pos_embed_state_dict) + + m.to(args.device) + + +# def i2v_linear_resize_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = True): +# # Rescale the grid of position embeddings when loading from state_dict +# old_pos_embed = state_dict.get('visual.positional_embedding', None) +# if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): +# return +# # grid_size = to_2tuple(model.visual.grid_size) +# grid_size = model.visual.grid_size +# extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) +# # new_seq_len = grid_size[0] * grid_size[1] + extra_tokens +# new_seq_len = grid_size[0] * grid_size[1] * grid_size[2] + extra_tokens +# if new_seq_len == old_pos_embed.shape[0]: +# return +# +# if extra_tokens: +# pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] +# else: +# pos_emb_tok, pos_emb_img = None, old_pos_embed +# # old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) +# +# logging.info('Resizing position embedding grid-size from %s to %s', old_pos_embed.shape[0], new_seq_len) +# # pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) +# pos_emb_img = pos_emb_img.unsqueeze(0).permute(0, 2, 1) +# pos_emb_img = F.interpolate( +# pos_emb_img, +# # size=grid_size, +# size=new_seq_len - extra_tokens, +# mode=interpolation, +# # antialias=antialias, +# # align_corners=False, +# ) +# # pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] +# pos_emb_img = pos_emb_img.permute(0, 2, 1)[0] +# if pos_emb_tok is not None: +# new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) +# else: +# new_pos_embed = pos_emb_img +# state_dict['visual.positional_embedding'] = new_pos_embed +# +# def inflate_patch_embed(state_dict, model): +# old_patch_embed_shape = model.visual.conv1.weight.shape +# new_patch_embed_shape = state_dict['visual.conv1.weight'].shape +# if old_patch_embed_shape == new_patch_embed_shape: +# return +# expanded_weight = state_dict['visual.conv1.weight'].unsqueeze(2).repeat(1, 1, 2, 1, 1) +# state_dict['visual.conv1.weight'] = expanded_weight +# +# +# def load_checkpoint(model, pretrained, strict=True): +# state_dict = load_state_dict(pretrained) +# # detect old format and make compatible with new format +# if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): +# state_dict = convert_to_custom_text_state_dict(state_dict) +# i2v_linear_resize_pos_embed(state_dict, model) +# inflate_patch_embed(state_dict, model) +# incompatible_keys = model.load_state_dict(state_dict, strict=strict) +# return incompatible_keys + diff --git a/open_clip/__init__.py b/open_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb1199b8aa87a919abff1bd0020c6624757ac62 --- /dev/null +++ b/open_clip/__init__.py @@ -0,0 +1,15 @@ +from .coca_model import CoCa +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ + get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub +from .tokenizer import SimpleTokenizer, tokenize, decode +from .transform import image_transform, AugmentationCfg +from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy +from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES diff --git a/open_clip/bpe_simple_vocab_16e6.txt.gz b/open_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/open_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/open_clip/coca_model.py b/open_clip/coca_model.py new file mode 100644 index 0000000000000000000000000000000000000000..039453af70d1c865dd7cc6016f732aff2f7dc3d2 --- /dev/null +++ b/open_clip/coca_model.py @@ -0,0 +1,458 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + +try: + from transformers import ( + BeamSearchScorer, + LogitsProcessorList, + TopPLogitsWarper, + TopKLogitsWarper, + RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MaxLengthCriteria, + StoppingCriteriaList + ) + + GENERATION_TYPES = { + "top_k": TopKLogitsWarper, + "top_p": TopPLogitsWarper, + "beam_search": "beam_search" + } + _has_transformers = True +except ImportError as e: + GENERATION_TYPES = { + "top_k": None, + "top_p": None, + "beam_search": "beam_search" + } + _has_transformers = False + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + attn_pooler_heads: int = 8 + + +def _build_text_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + + self.text = _build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.text_decoder = _build_text_decoder_tower( + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.pad_id = pad_id + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) + + def _encode_image(self, images, normalize=True): + image_latent, tokens_embs = self.visual(images) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, tokens_embs + + def _encode_text(self, text, normalize=True, embed_cls=True): + text = text[:, :-1] if embed_cls else text # make space for CLS token + text_latent, token_emb = self.text(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize=True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize=True, embed_cls=True): + text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) + return text_latent + + def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): + text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + + # TODO: add assertion to avoid bugs? + labels = text[:, -token_embs.shape[1]:] + + logits = self.text_decoder(image_embs, token_embs) + return { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "labels": labels, + "logit_scale": self.logit_scale.exp() + } + + def generate( + self, + image, + text=None, + seq_len=30, + max_seq_len=77, + temperature=1., + generation_type="beam_search", + top_p=0.1, # keep tokens in the 1 - top_p quantile + top_k=1, # keeps the top_k most probable tokens + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + repetition_penalty=1.0, + fixed_output_length=False # if True output.shape == (batch_size, seq_len) + ): + # taking many ideas and components from HuggingFace GenerationMixin + # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation + assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." + assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + + with torch.no_grad(): + sot_token_id = 49406 if sot_token_id is None else sot_token_id + eos_token_id = 49407 if eos_token_id is None else eos_token_id + pad_token_id = self.pad_id if pad_token_id is None else pad_token_id + logit_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(min_seq_len, eos_token_id), + RepetitionPenaltyLogitsProcessor(repetition_penalty), + ] + ) + + if stopping_criteria is None: + stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] + + stopping_criteria = StoppingCriteriaList( + stopping_criteria + ) + + device = image.device + + if generation_type == "beam_search": + output = self._generate_beamsearch( + image_inputs = image, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + sot_token_id=sot_token_id, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + min_seq_len=min_seq_len, + stopping_criteria=stopping_criteria, + logit_processor=logit_processor, + ) + if fixed_output_length and output.shape[1] < seq_len: + return torch.cat( + (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), + dim=1 + ) + return output + + elif generation_type == "top_p": + logit_warper = GENERATION_TYPES[generation_type](top_p) + elif generation_type == "top_k": + logit_warper = GENERATION_TYPES[generation_type](top_k) + else: + raise ValueError( + f"generation_type has to be one of " + f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." + ) + + image_latent, image_embs = self._encode_image(image) + + if text is None: + text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id + + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + cur_len = text.shape[1] + self.eval() + out = text + + while True: + x = out[:, -max_seq_len:] + cur_len = x.shape[1] + logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1] + mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) + sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id + + if mask.all(): + if not fixed_output_length: + break + else: + logits = logits[~mask, :] + filtered_logits = logit_processor(x[~mask, :], logits) + filtered_logits = logit_warper(x[~mask, :], filtered_logits) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + if (cur_len + 1 == seq_len): + sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id + else: + sample[~mask, :] = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + cur_len += 1 + + if stopping_criteria(out, None): + break + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out + + def _generate_beamsearch( + self, + image_inputs, + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + logit_processor=None, + logit_warper=None, + ): + device = image_inputs.device + batch_size = image_inputs.shape[0] + image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) + image_latent, image_embs = self._encode_image(image_inputs) + + input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) + input_ids = input_ids * sot_token_id + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=device, + num_beam_groups=num_beam_groups, + ) + # instantiate logits processors + logits_processor = ( + LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) + if logit_processor is None + else logit_processor + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_beam_size, cur_len = input_ids.shape + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while True: + + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) + outputs = self( + model_inputs['images'], + model_inputs['text'], + embed_cls=False, + image_latent=image_latent, + image_embs=image_embs + ) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of currentg group only + next_token_logits = outputs['logits'][batch_group_indices, -1, :] + vocab_size = next_token_logits.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # increase cur_len + cur_len = cur_len + 1 + if beam_scorer.is_done or stopping_criteria(input_ids, None): + break + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + ) + return sequence_outputs['sequences'] + + +def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None + return { + "text": input_ids, + "images": image_inputs, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } diff --git a/open_clip/constants.py b/open_clip/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee --- /dev/null +++ b/open_clip/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/open_clip/factory.py b/open_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..1ecdc4ce71b3dc6e3f82fc3d5d416e78022db3c9 --- /dev/null +++ b/open_clip/factory.py @@ -0,0 +1,382 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ + resize_pos_embed, get_cast_dtype +from .coca_model import CoCa +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ + list_pretrained_tags_by_model, download_pretrained_from_hf +from .transform import image_transform, AugmentationCfg +from .tokenizer import HFTokenizer, tokenize + + +HF_HUB_PREFIX = 'hf-hub:' +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, 'r') as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name, cache_dir): + if model_name.startswith(HF_HUB_PREFIX): + tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):], cache_dir) + else: + config = get_model_config(model_name) + tokenizer = HFTokenizer( + config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + return tokenizer + + +def load_state_dict(checkpoint_path: str, map_location='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + + +def load_checkpoint(model, checkpoint_path, strict=True): + state_dict = load_state_dict(checkpoint_path) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + resize_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, +): + has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) + if has_hf_hub_prefix: + model_id = model_name[len(HF_HUB_PREFIX):] + checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) + + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + pretrained_cfg = config['preprocess_cfg'] + model_cfg = config['model_cfg'] + else: + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + checkpoint_path = None + pretrained_cfg = {} + model_cfg = None + + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == 'openai': + logging.info(f'Loading pretrained {model_name} from OpenAI.') + model = load_openai_model( + model_name, + precision=precision, + device=device, + cache_dir=cache_dir, + ) + else: + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') + else: + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) + if pretrained_image: + if is_timm_model: + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + else: + assert False, 'pretrained image towers currently only supported for timm models' + + # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + if custom_text: + if is_hf_model: + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf + if "coca" in model_name: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + if precision in ("fp16", "bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + # manual mixed precision that matches original OpenAI behaviour + if is_timm_model: + # FIXME this is a bit janky, create timm based model in low-precision and + # then cast only LayerNormFp32 instances back to float32 so they don't break. + # Why? The convert_weights_to_lp fn only works with native models. + model.to(device=device, dtype=dtype) + from .transformer import LayerNormFp32 + def _convert_ln(m): + if isinstance(m, LayerNormFp32): + m.weight.data = m.weight.data.to(torch.float32) + m.bias.data = m.bias.data.to(torch.float32) + model.apply(_convert_ln) + else: + model.to(device=device) + convert_weights_to_lp(model, dtype=dtype) + elif precision in ("pure_fp16", "pure_bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model.to(device=device, dtype=dtype) + else: + model.to(device=device) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' + f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logging.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD + + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + + if jit: + model = torch.jit.script(model) + + return model + + +def create_loss(args): + if args.distill: + return DistillClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + elif "coca" in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_patch_dropout=force_patch_dropout, + force_image_size=force_image_size, + pretrained_image=pretrained_image, + pretrained_hf=pretrained_hf, + cache_dir=cache_dir, + output_dict=output_dict, + ) + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std, + aug_cfg=aug_cfg, + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess_train, preprocess_val + + +def create_model_from_pretrained( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_image_size=force_image_size, + cache_dir=cache_dir, + require_pretrained=True, + ) + + if not return_transform: + return model + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess diff --git a/open_clip/generation_utils.py b/open_clip/generation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/open_clip/hf_configs.py b/open_clip/hf_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..13c9bfd8c660eac59f1fbc1912b9fccc9c0c625a --- /dev/null +++ b/open_clip/hf_configs.py @@ -0,0 +1,56 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/bert + "bert": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + }, + "pooler": "cls_pooler", + }, +} diff --git a/open_clip/hf_model.py b/open_clip/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..08dbdbcde02b550ca765ca9bcb0b667be2c0443d --- /dev/null +++ b/open_clip/hf_model.py @@ -0,0 +1,193 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. +""" +import re + +import torch +import torch.nn as nn +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ + BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + + class BaseModelOutput: + pass + + + class PretrainedConfig: + pass + +from .hf_configs import arch_dict + + +# utils +def _camel2snake(s): + return re.sub(r'(? torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward(self, image_features, text_features, logit_scale, output_dict=False): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): + + clip_loss = 0 + + if self.clip_loss_weight: + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + if output_dict: + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} + + return clip_loss, caption_loss + + +class DistillClipLoss(ClipLoss): + + def dist_loss(self, teacher_logits, student_logits): + return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) + + def forward( + self, + image_features, + text_features, + logit_scale, + dist_image_features, + dist_text_features, + dist_logit_scale, + output_dict=False, + ): + logits_per_image, logits_per_text = \ + self.get_logits(image_features, text_features, logit_scale) + + dist_logits_per_image, dist_logits_per_text = \ + self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) + + labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) + + contrastive_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + distill_loss = ( + self.dist_loss(dist_logits_per_image, logits_per_image) + + self.dist_loss(dist_logits_per_text, logits_per_text) + ) / 2 + + if output_dict: + return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} + + return contrastive_loss, distill_loss diff --git a/open_clip/model.py b/open_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f85b68ba23117cb65d082cf5cd4cf7528bab4619 --- /dev/null +++ b/open_clip/model.py @@ -0,0 +1,473 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +from dataclasses import dataclass +import logging +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from .hf_model import HFTextEncoder +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer +from .utils import to_2tuple + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling + output_tokens: bool = False + + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + timm_drop: float = 0. # head dropout + timm_drop_path: Optional[float] = None # backbone stochastic depth + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + embed_cls: bool = False + pad_id: int = 0 + output_tokens: bool = False + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def get_input_dtype(precision: str): + input_dtype = None + if precision in ('bf16', 'pure_bf16'): + input_dtype = torch.bfloat16 + elif precision in ('fp16', 'pure_fp16'): + input_dtype = torch.float16 + return input_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + drop=vision_cfg.timm_drop, + drop_path=vision_cfg.timm_drop_path, + patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, + embed_dim=embed_dim, + image_size=vision_cfg.image_size, + ) + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width, + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + input_patchnorm=vision_cfg.input_patchnorm, + global_average_pool=vision_cfg.global_average_pool, + attentional_pool=vision_cfg.attentional_pool, + n_queries=vision_cfg.n_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + output_tokens=vision_cfg.output_tokens, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + proj=text_cfg.proj, + pooler_type=text_cfg.pooler_type, + pretrained=text_cfg.hf_model_pretrained, + output_tokens=text_cfg.output_tokens, + ) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + output_tokens=text_cfg.output_tokens, + pad_id=text_cfg.pad_id, + act_layer=act_layer, + norm_layer=norm_layer, + ) + return text + + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.context_length = text.context_length + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +class CustomTextCLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.context_length = self.text.context_length + self.vocab_size = self.text.vocab_size + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + if isinstance(l, (CLIP, TextTransformer)): + # convert text nn.Parameter projections + attr = getattr(l, "text_projection", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + if isinstance(l, VisionTransformer): + # convert vision nn.Parameter projections + attr = getattr(l, "proj", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' + k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers, + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + antialias=antialias, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed diff --git a/open_clip/model_configs/EVA01-g-14-plus.json b/open_clip/model_configs/EVA01-g-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..73f46a71e664fce987218b8eb48903e7bd895f41 --- /dev/null +++ b/open_clip/model_configs/EVA01-g-14-plus.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/model_configs/EVA01-g-14.json b/open_clip/model_configs/EVA01-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..9d0e80f290d9491b7c46fafd576201b1258165aa --- /dev/null +++ b/open_clip/model_configs/EVA01-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/model_configs/EVA02-B-16.json b/open_clip/model_configs/EVA02-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..3f92357287e1f6600da1e7f391cb6370d7f66de4 --- /dev/null +++ b/open_clip/model_configs/EVA02-B-16.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_base_patch16_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/model_configs/EVA02-E-14-plus.json b/open_clip/model_configs/EVA02-E-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..e250c2a404c86ff168c54cfcf71bc2492be1b74c --- /dev/null +++ b/open_clip/model_configs/EVA02-E-14-plus.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/model_configs/EVA02-E-14.json b/open_clip/model_configs/EVA02-E-14.json new file mode 100644 index 0000000000000000000000000000000000000000..4b6648e25092b151a9095e0a66956c7ebf835b16 --- /dev/null +++ b/open_clip/model_configs/EVA02-E-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/model_configs/EVA02-L-14-336.json b/open_clip/model_configs/EVA02-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..2bb07f3c082fd88c4e86131b272163aaacfaef9e --- /dev/null +++ b/open_clip/model_configs/EVA02-L-14-336.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "timm_model_name": "eva02_large_patch14_clip_336", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/model_configs/EVA02-L-14.json b/open_clip/model_configs/EVA02-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b4c7f377bc543aa92a145358f2630a58ae9be989 --- /dev/null +++ b/open_clip/model_configs/EVA02-L-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_large_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/model_configs/RN101-quickgelu.json b/open_clip/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/RN101.json b/open_clip/model_configs/RN101.json new file mode 100644 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/RN50-quickgelu.json b/open_clip/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/open_clip/model_configs/RN50.json b/open_clip/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/RN50x16.json b/open_clip/model_configs/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/RN50x4.json b/open_clip/model_configs/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/RN50x64.json b/open_clip/model_configs/RN50x64.json new file mode 100644 index 0000000000000000000000000000000000000000..f5aaa2ee3de21ddb03cbd12766a3419bf34898c7 --- /dev/null +++ b/open_clip/model_configs/RN50x64.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 448, + "layers": [ + 3, + 15, + 36, + 10 + ], + "width": 128, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-16-plus-240.json b/open_clip/model_configs/ViT-B-16-plus-240.json new file mode 100644 index 0000000000000000000000000000000000000000..5bbd12bcd01f64d6d0a0aa8316b129327a0d169a --- /dev/null +++ b/open_clip/model_configs/ViT-B-16-plus-240.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 240, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-16-plus.json b/open_clip/model_configs/ViT-B-16-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..5dc1e09baccef2b15055c1bffeb9903e760101c6 --- /dev/null +++ b/open_clip/model_configs/ViT-B-16-plus.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-16.json b/open_clip/model_configs/ViT-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-32-plus-256.json b/open_clip/model_configs/ViT-B-32-plus-256.json new file mode 100644 index 0000000000000000000000000000000000000000..2f09c857de9a4c01ae51297a7e2451984879f9de --- /dev/null +++ b/open_clip/model_configs/ViT-B-32-plus-256.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 256, + "layers": 12, + "width": 896, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-32-quickgelu.json b/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-32.json b/open_clip/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-H-14.json b/open_clip/model_configs/ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74 --- /dev/null +++ b/open_clip/model_configs/ViT-H-14.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-H-16.json b/open_clip/model_configs/ViT-H-16.json new file mode 100644 index 0000000000000000000000000000000000000000..588485455fdf8193ec16474450b94e31c91ea93c --- /dev/null +++ b/open_clip/model_configs/ViT-H-16.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-L-14-280.json b/open_clip/model_configs/ViT-L-14-280.json new file mode 100644 index 0000000000000000000000000000000000000000..2262deaefa82792d35d73c0d7c8e620525092581 --- /dev/null +++ b/open_clip/model_configs/ViT-L-14-280.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 280, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-L-14-336.json b/open_clip/model_configs/ViT-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..8d1f74c2639c3a3705df9865b9c08215675ddc97 --- /dev/null +++ b/open_clip/model_configs/ViT-L-14-336.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-L-14.json b/open_clip/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-L-16-320.json b/open_clip/model_configs/ViT-L-16-320.json new file mode 100644 index 0000000000000000000000000000000000000000..fc2d13ca9ec7f0b56a886ddaf66c4a7ba7a442ba --- /dev/null +++ b/open_clip/model_configs/ViT-L-16-320.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 320, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-L-16.json b/open_clip/model_configs/ViT-L-16.json new file mode 100644 index 0000000000000000000000000000000000000000..82a1cedfa290adacbbdc02bc5d589734c22d41d3 --- /dev/null +++ b/open_clip/model_configs/ViT-L-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-M-16-alt.json b/open_clip/model_configs/ViT-M-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..1a317aad8e02d9c26d2decc7cc49a18dfdf9e0d8 --- /dev/null +++ b/open_clip/model_configs/ViT-M-16-alt.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16, + "ls_init_value": 1e-4 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-M-16.json b/open_clip/model_configs/ViT-M-16.json new file mode 100644 index 0000000000000000000000000000000000000000..f2f3225a46e09237730a151d161f70c86b985172 --- /dev/null +++ b/open_clip/model_configs/ViT-M-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-M-32-alt.json b/open_clip/model_configs/ViT-M-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..fd222aeac0f582ef6a1a33f1b3fec70a5b386ac0 --- /dev/null +++ b/open_clip/model_configs/ViT-M-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-M-32.json b/open_clip/model_configs/ViT-M-32.json new file mode 100644 index 0000000000000000000000000000000000000000..4f718642821035d9776d1e006817d65ede074366 --- /dev/null +++ b/open_clip/model_configs/ViT-M-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-S-16-alt.json b/open_clip/model_configs/ViT-S-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..a8c056555e4da3ba0d1475a61fc316362ecce76f --- /dev/null +++ b/open_clip/model_configs/ViT-S-16-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-S-16.json b/open_clip/model_configs/ViT-S-16.json new file mode 100644 index 0000000000000000000000000000000000000000..1d8504e59658803f3093e5b05de45f30a09b8185 --- /dev/null +++ b/open_clip/model_configs/ViT-S-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-S-32-alt.json b/open_clip/model_configs/ViT-S-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..e1dfdec9824df09a2010e991ccfa1d9ee2f45807 --- /dev/null +++ b/open_clip/model_configs/ViT-S-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-S-32.json b/open_clip/model_configs/ViT-S-32.json new file mode 100644 index 0000000000000000000000000000000000000000..9b8b4191b268de267268cfcb90fc01c6b9df07d8 --- /dev/null +++ b/open_clip/model_configs/ViT-S-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-bigG-14.json b/open_clip/model_configs/ViT-bigG-14.json new file mode 100644 index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7 --- /dev/null +++ b/open_clip/model_configs/ViT-bigG-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-e-14.json b/open_clip/model_configs/ViT-e-14.json new file mode 100644 index 0000000000000000000000000000000000000000..91a0fe14d25a107fb8ec48dd7faae313fd26ed7b --- /dev/null +++ b/open_clip/model_configs/ViT-e-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 56, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.5715, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 36 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-g-14.json b/open_clip/model_configs/ViT-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..8c4b7325cc75b6112be7107d36ae2cb5762d9091 --- /dev/null +++ b/open_clip/model_configs/ViT-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/coca_ViT-B-32.json b/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..7e7eb520a6a0096e5602d509ecd6186e278f4725 --- /dev/null +++ b/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "attn_pooler_heads": 8 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/model_configs/coca_ViT-L-14.json b/open_clip/model_configs/coca_ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3d5ca4ca2338540f06852df5ff35ea6277e64555 --- /dev/null +++ b/open_clip/model_configs/coca_ViT-L-14.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "attn_pooler_heads": 12 + }, + "custom_text": true +} diff --git a/open_clip/model_configs/coca_base.json b/open_clip/model_configs/coca_base.json new file mode 100644 index 0000000000000000000000000000000000000000..cf8c6cecb78a49d7e7140145a0307cbd561077c2 --- /dev/null +++ b/open_clip/model_configs/coca_base.json @@ -0,0 +1,31 @@ +{ + "embed_dim": 512, + "multimodal_cfg": { + "width": 768, + "context_length": 76, + "vocab_size": 64000, + "mlp_ratio": 4, + "layers": 12, + "dim_head": 64, + "heads": 12, + "n_queries": 256, + "attn_pooler_heads": 8 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 64000, + "layers": 12, + "heads": 12, + "width": 768, + "embed_cls": true, + "output_tokens": true + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/model_configs/coca_roberta-ViT-B-32.json b/open_clip/model_configs/coca_roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..fb46354b95a17a46d7fcfd9d504e917ee6c1608c --- /dev/null +++ b/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "output_tokens": true + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "linear", + "width": 768, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "width": 768, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} diff --git a/open_clip/model_configs/convnext_base.json b/open_clip/model_configs/convnext_base.json new file mode 100644 index 0000000000000000000000000000000000000000..bb6dba181d950ea5081155c90d47e72c94816b80 --- /dev/null +++ b/open_clip/model_configs/convnext_base.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_base_w.json b/open_clip/model_configs/convnext_base_w.json new file mode 100644 index 0000000000000000000000000000000000000000..82ea7ae3659e5514f37ff982f0ab1141dff4bd18 --- /dev/null +++ b/open_clip/model_configs/convnext_base_w.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_base_w_320.json b/open_clip/model_configs/convnext_base_w_320.json new file mode 100644 index 0000000000000000000000000000000000000000..0a07c4e16abaa4015ecc5f82ec845de16e1f9d88 --- /dev/null +++ b/open_clip/model_configs/convnext_base_w_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_large.json b/open_clip/model_configs/convnext_large.json new file mode 100644 index 0000000000000000000000000000000000000000..c4a1fea73dbead71c218a0e74b9b15f9b252e3ef --- /dev/null +++ b/open_clip/model_configs/convnext_large.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_large_d.json b/open_clip/model_configs/convnext_large_d.json new file mode 100644 index 0000000000000000000000000000000000000000..ae8fed21b58e1a6a411daf8b792ee50f0ab42346 --- /dev/null +++ b/open_clip/model_configs/convnext_large_d.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_large_d_320.json b/open_clip/model_configs/convnext_large_d_320.json new file mode 100644 index 0000000000000000000000000000000000000000..54c3df36a6f56ace0b12ada24c13058de96feed8 --- /dev/null +++ b/open_clip/model_configs/convnext_large_d_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_small.json b/open_clip/model_configs/convnext_small.json new file mode 100644 index 0000000000000000000000000000000000000000..3592c2a5cd21aae8d2544931773cf7603f67ea28 --- /dev/null +++ b/open_clip/model_configs/convnext_small.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_small", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_tiny.json b/open_clip/model_configs/convnext_tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..ad11470f5ec40ffec771096971ce58d3d5b9249b --- /dev/null +++ b/open_clip/model_configs/convnext_tiny.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_tiny", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_xlarge.json b/open_clip/model_configs/convnext_xlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..2a909965932eef994177c829fefc2bdc1c219b3f --- /dev/null +++ b/open_clip/model_configs/convnext_xlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 20 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_xxlarge.json b/open_clip/model_configs/convnext_xxlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..23a55a681c346d1a315d8a163c1cb6ad495e6a91 --- /dev/null +++ b/open_clip/model_configs/convnext_xxlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_xxlarge_320.json b/open_clip/model_configs/convnext_xxlarge_320.json new file mode 100644 index 0000000000000000000000000000000000000000..ac5134ca12cbaa97772cde059270d345386a74c7 --- /dev/null +++ b/open_clip/model_configs/convnext_xxlarge_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/mt5-base-ViT-B-32.json b/open_clip/model_configs/mt5-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..58cad89cf0f446bbe15e4e25b1ac43424a828017 --- /dev/null +++ b/open_clip/model_configs/mt5-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "google/mt5-base", + "hf_tokenizer_name": "google/mt5-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/open_clip/model_configs/mt5-xl-ViT-H-14.json b/open_clip/model_configs/mt5-xl-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b432810777ba7269dbb0e89edfe65cdd27e7d255 --- /dev/null +++ b/open_clip/model_configs/mt5-xl-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "google/mt5-xl", + "hf_tokenizer_name": "google/mt5-xl", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/open_clip/model_configs/roberta-ViT-B-32.json b/open_clip/model_configs/roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..ed687d472a73bb2ac96025f355f80437ab14c260 --- /dev/null +++ b/open_clip/model_configs/roberta-ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/open_clip/model_configs/swin_base_patch4_window7_224.json b/open_clip/model_configs/swin_base_patch4_window7_224.json new file mode 100644 index 0000000000000000000000000000000000000000..bd6820f0cf2aa655e0a2723287f4b78895a58e6a --- /dev/null +++ b/open_clip/model_configs/swin_base_patch4_window7_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "swin_base_patch4_window7_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/vit_medium_patch16_gap_256.json b/open_clip/model_configs/vit_medium_patch16_gap_256.json new file mode 100644 index 0000000000000000000000000000000000000000..8843eaf08cad16c3e7b5f496fd650715c9573f65 --- /dev/null +++ b/open_clip/model_configs/vit_medium_patch16_gap_256.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_medium_patch16_gap_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json b/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json new file mode 100644 index 0000000000000000000000000000000000000000..ed217b202d5e6071c5307f4547c97ff4cfe2abd1 --- /dev/null +++ b/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_relpos_medium_patch16_cls_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json b/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..751bccc2c6fc41bc4ff20182de88d86739d518d9 --- /dev/null +++ b/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-base", + "hf_tokenizer_name": "xlm-roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json b/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..31f271faa9bbb7a9da53900b483a4c00a16f3c4a --- /dev/null +++ b/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-large", + "hf_tokenizer_name": "xlm-roberta-large", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/open_clip/modified_resnet.py b/open_clip/modified_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c0b033a80e7d08a20a367050c5b1bc5d5292e7 --- /dev/null +++ b/open_clip/modified_resnet.py @@ -0,0 +1,181 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from open_clip.utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0., + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/open_clip/openai.py b/open_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2c0235245c2e4f1217b3b2bfaf2acf78e74981 --- /dev/null +++ b/open_clip/openai.py @@ -0,0 +1,90 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag('openai') + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + + if get_pretrained_url(name, 'openai'): + model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location="cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + state_dict = torch.load(model_path, map_location="cpu") + + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + # FIXME support pure fp16/bf16 precision modes + if precision != 'fp16': + model.float() + if precision == 'bf16': + # for bf16, convert back to low-precision + convert_weights_to_lp(model, dtype=torch.bfloat16) + + # add mean / std attributes for consistency with OpenCLIP models + model.visual.image_mean = OPENAI_DATASET_MEAN + model.visual.image_std = OPENAI_DATASET_STD + return model diff --git a/open_clip/pretrained.py b/open_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..1465a2325652be7e7a1d7563698e38b9ec408cc6 --- /dev/null +++ b/open_clip/pretrained.py @@ -0,0 +1,427 @@ +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Union + +from tqdm import tqdm + +from .version import __version__ + +try: + from huggingface_hub import hf_hub_download + hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + + +_RN50 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN50_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN101 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN101_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN50x4 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), +) + +_RN50x16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), +) + +_RN50x64 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), +) + +_VITB32 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), + # DataComp-M models + datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'), + commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'), + commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'), + commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'), + commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'), + commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'), + commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'), + # DataComp-S models + datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'), + commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'), + commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'), + commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'), + commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), + commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), + commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), +) + +_VITB32_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), +) + +_VITB16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), + # DataComp-L models + datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'), + commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'), + commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'), + commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'), + commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'), + commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), + commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'), + commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), + commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), + commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), +) + +_VITL14_336 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), +) + +_robertaViTB32 = dict( + laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), +) + +_xlmRobertaBaseViTB32 = dict( + laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), +) + +_xlmRobertaLargeFrozenViTH14 = dict( + frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), +) + +_convnext_base = dict( + laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), +) + +_convnext_base_w = dict( + laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), + laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), +) + +_convnext_base_w_320 = dict( + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), + laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), +) + +_convnext_large_d = dict( + laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), +) + +_convnext_large_d_320 = dict( + laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), + laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), +) + +_convnext_xxlarge = dict( + laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), + laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), + laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), +) + +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + +_coca_VITL14 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') +) + + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "RN50x64": _RN50x64, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-B-16-plus-240": _VITB16_PLUS_240, + "ViT-L-14": _VITL14, + "ViT-L-14-336": _VITL14_336, + "ViT-H-14": _VITH14, + "ViT-g-14": _VITg14, + "ViT-bigG-14": _VITbigG14, + "roberta-ViT-B-32": _robertaViTB32, + "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, + "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, + "convnext_base": _convnext_base, + "convnext_base_w": _convnext_base_w, + "convnext_base_w_320": _convnext_base_w_320, + "convnext_large_d": _convnext_large_d, + "convnext_large_d_320": _convnext_large_d_320, + "convnext_xxlarge": _convnext_xxlarge, + "coca_ViT-B-32": _coca_VITB32, + "coca_ViT-L-14": _coca_VITL14, + "EVA01-g-14": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt + laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), + ), + "EVA01-g-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt + merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'), + ), + "EVA02-B-16": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt + merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'), + ), + "EVA02-L-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt + merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'), + ), + "EVA02-L-14-336": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt + merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'), + ), + "EVA02-E-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt + laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'), + ), + "EVA02-E-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt + laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), + ) +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = '' + if not cfg: + return target + + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. + model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/open_clip/push_to_hf_hub.py b/open_clip/push_to_hf_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6271da1d35e36ea22e92d339dc9465d0793249 --- /dev/null +++ b/open_clip/push_to_hf_hub.py @@ -0,0 +1,280 @@ +import argparse +import json +import os +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Tuple, Union + +import torch + +try: + from huggingface_hub import ( + create_repo, + get_hf_file_metadata, + hf_hub_download, + hf_hub_url, + repo_type_and_id_from_hf_id, + upload_folder, + list_repo_files, + ) + from huggingface_hub.utils import EntryNotFoundError + _has_hf_hub = True +except ImportError: + _has_hf_hub = False + +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False + +from .factory import create_model_from_pretrained, get_model_config, get_tokenizer +from .tokenizer import HFTokenizer + +# Default name for a weights file hosted on the Huggingface Hub. +HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl +HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version +HF_CONFIG_NAME = 'open_clip_config.json' + +def save_config_for_hf( + model, + config_path: str, + model_config: Optional[dict] +): + preprocess_cfg = { + 'mean': model.visual.image_mean, + 'std': model.visual.image_std, + } + hf_config = { + 'model_cfg': model_config, + 'preprocess_cfg': preprocess_cfg, + } + + with config_path.open('w') as f: + json.dump(hf_config, f, indent=2) + + +def save_for_hf( + model, + tokenizer: HFTokenizer, + model_config: dict, + save_directory: str, + safe_serialization: Union[bool, str] = False, + skip_weights : bool = False, +): + config_filename = HF_CONFIG_NAME + + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + + if not skip_weights: + tensors = model.state_dict() + if safe_serialization is True or safe_serialization == "both": + assert _has_safetensors, "`pip install safetensors` to use .safetensors" + safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) + if safe_serialization is False or safe_serialization == "both": + torch.save(tensors, save_directory / HF_WEIGHTS_NAME) + + tokenizer.save_pretrained(save_directory) + + config_path = save_directory / config_filename + save_config_for_hf(model, config_path, model_config=model_config) + + +def push_to_hf_hub( + model, + tokenizer, + model_config: Optional[dict], + repo_id: str, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, + safe_serialization: Union[bool, str] = False, +): + if not isinstance(tokenizer, HFTokenizer): + # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 + tokenizer = HFTokenizer('openai/clip-vit-large-patch14') + + # Create repo if it doesn't exist yet + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) + + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f"{repo_owner}/{repo_name}" + + # Check if repo already exists and determine what needs updating + repo_exists = False + repo_files = {} + try: + repo_files = set(list_repo_files(repo_id)) + repo_exists = True + except Exception as e: + print('Repo does not exist', e) + + try: + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. + save_for_hf( + model, + tokenizer=tokenizer, + model_config=model_config, + save_directory=tmpdir, + safe_serialization=safe_serialization, + ) + + # Add readme if it does not exist + if not has_readme: + model_card = model_card or {} + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / "README.md" + readme_text = generate_readme(model_card, model_name) + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) + + +def push_pretrained_to_hf_hub( + model_name, + pretrained: str, + repo_id: str, + precision: str = 'fp32', + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, +): + model, preprocess_eval = create_model_from_pretrained( + model_name, + pretrained=pretrained, + precision=precision, + image_mean=image_mean, + image_std=image_std, + ) + + model_config = get_model_config(model_name) + assert model_config + + tokenizer = get_tokenizer(model_name) + + push_to_hf_hub( + model=model, + tokenizer=tokenizer, + model_config=model_config, + repo_id=repo_id, + commit_message=commit_message, + token=token, + revision=revision, + private=private, + create_pr=create_pr, + model_card=model_card, + safe_serialization='both', + ) + + +def generate_readme(model_card: dict, model_name: str): + readme_text = "---\n" + readme_text += "tags:\n- clip\n" + readme_text += "library_name: open_clip\n" + readme_text += "pipeline_tag: zero-shot-image-classification\n" + readme_text += f"license: {model_card.get('license', 'mit')}\n" + if 'details' in model_card and 'Dataset' in model_card['details']: + readme_text += 'datasets:\n' + readme_text += f"- {model_card['details']['Dataset'].lower()}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + if isinstance(v, (list, tuple)): + readme_text += f"- **{k}:**\n" + for vi in v: + readme_text += f" - {vi}\n" + elif isinstance(v, dict): + readme_text += f"- **{k}:**\n" + for ki, vi in v.items(): + readme_text += f" - {ki}: {vi}\n" + else: + readme_text += f"- **{k}:** {v}\n" + if 'usage' in model_card: + readme_text += f"\n## Model Usage\n" + readme_text += model_card['usage'] + readme_text += '\n' + + if 'comparison' in model_card: + readme_text += f"\n## Model Comparison\n" + readme_text += model_card['comparison'] + readme_text += '\n' + + if 'citation' in model_card: + readme_text += f"\n## Citation\n" + if not isinstance(model_card['citation'], (list, tuple)): + citations = [model_card['citation']] + else: + citations = model_card['citation'] + for c in citations: + readme_text += f"```bibtex\n{c}\n```\n" + + return readme_text + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") + parser.add_argument( + "--model", type=str, help="Name of the model to use.", + ) + parser.add_argument( + "--pretrained", type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--repo-id", type=str, + help="Destination HF Hub repo-id ie 'organization/model_id'.", + ) + parser.add_argument( + "--precision", type=str, default='fp32', + ) + parser.add_argument( + '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override default image mean value of dataset') + parser.add_argument( + '--image-std', type=float, nargs='+', default=None, metavar='STD', + help='Override default image std deviation of of dataset') + args = parser.parse_args() + + print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') + + # FIXME add support to pass model_card json / template from file via cmd line + + push_pretrained_to_hf_hub( + args.model, + args.pretrained, + args.repo_id, + precision=args.precision, + image_mean=args.image_mean, # override image mean/std if trained w/ non defaults + image_std=args.image_std, + ) + + print(f'{args.model} saved.') diff --git a/open_clip/timm_model.py b/open_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3f595d67cdedd142b6312d26924e8e58c67086 --- /dev/null +++ b/open_clip/timm_model.py @@ -0,0 +1,149 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + proj_bias=False, + drop=0., + drop_path=None, + patch_drop=None, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + self.image_size = to_2tuple(image_size) + + # setup kwargs that may not be common across all models + timm_kwargs = {} + if drop_path is not None: + timm_kwargs['drop_path_rate'] = drop_path + if patch_drop is not None: + timm_kwargs['patch_drop_rate'] = patch_drop + + custom_pool = pool in ('abs_attn', 'rot_attn') + if not proj and not custom_pool: + # use network classifier head as projection if no proj specified and no custom pooling used + self.trunk = timm.create_model( + model_name, + num_classes=embed_dim, + global_pool=pool, + pretrained=pretrained, + **timm_kwargs, + ) + prev_chs = embed_dim + else: + self.trunk = timm.create_model( + model_name, + pretrained=pretrained, + **timm_kwargs, + ) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if custom_pool: + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + + # Add custom pooling to head + if pool == 'abs_attn': + head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) + else: + assert not proj, f'Unknown projection type {proj}.' + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/open_clip/tokenizer.py b/open_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..97ac3e804fd1d421183c2764c020a723149825d1 --- /dev/null +++ b/open_clip/tokenizer.py @@ -0,0 +1,214 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', ''] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + +def decode(output_ids: torch.Tensor): + output_ids = output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__(self, tokenizer_name: str, cache_dir=None): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + output = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ) + return output.input_ids, output.attention_mask diff --git a/open_clip/transform.py b/open_clip/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..748884a3c7cb7ece1ca521ca1dbf40bb74855007 --- /dev/null +++ b/open_clip/transform.py @@ -0,0 +1,133 @@ +import warnings +from dataclasses import dataclass, asdict +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None + interpolation: Optional[str] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + if isinstance(aug_cfg, dict): + aug_cfg = AugmentationCfg(**aug_cfg) + else: + aug_cfg = aug_cfg or AugmentationCfg() + normalize = Normalize(mean=mean, std=std) + if is_train: + aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + use_timm = aug_cfg_dict.pop('use_timm', False) + if use_timm: + from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): + assert len(image_size) >= 2 + input_size = (3,) + image_size[-2:] + else: + input_size = (3, image_size, image_size) + # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time + aug_cfg_dict.setdefault('interpolation', 'random') + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + train_transform = create_transform( + input_size=input_size, + is_training=True, + hflip=0., + mean=mean, + std=std, + re_mode='pixel', + **aug_cfg_dict, + ) + else: + train_transform = Compose([ + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ]) + if aug_cfg_dict: + warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + return train_transform + else: + if resize_longest_max: + transforms = [ + ResizeMaxSize(image_size, fill=fill_color) + ] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) diff --git a/open_clip/transformer.py b/open_clip/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..3c04436fb0cc26dbcbdf72e52397da93744b078c --- /dev/null +++ b/open_clip/transformer.py @@ -0,0 +1,737 @@ +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence, Tuple + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x, B, T): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + if T == 1: + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + else: + rand = torch.randn(B, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + patch_indices_keep = patch_indices_keep.unsqueeze(1).repeat(1, T, 1) + patch_indices_keep = rearrange(patch_indices_keep, 'b t n -> (b t) n') + + + + + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): + return self.resblocks[0].mlp.c_fc.int8_original_dtype + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + input_patchnorm: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False + ): + super().__init__() + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 + self.input_patchnorm = input_patchnorm + + if input_patchnorm: + patch_input_dim = patch_height * patch_width * 3 + self.patchnorm_pre_ln = LayerNorm(patch_input_dim) + self.conv1 = nn.Linear(patch_input_dim, width) + else: + self.patchnorm_pre_ln = nn.Identity() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.ln_pre = norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.global_average_pool = global_average_pool + if attentional_pool: + self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + else: + self.attn_pool = None + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. + + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + + def forward(self, x: torch.Tensor): + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) + x = self.patchnorm_pre_ln(x) + x = self.conv1(x) + else: + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if self.attn_pool is not None: + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + else: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, + output_tokens: bool = False, + ): + super().__init__() + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def _repeat(self, t, N: int): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + if self.cls_emb is not None: + pooled, tokens = x[:, -1], x[:, :-1] + pooled = self.ln_final(pooled) + else: + x = self.ln_final(x) + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + + if self.text_projection is not None: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection + + return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/open_clip/utils.py b/open_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0bb8868ae1f2d31493ca32b73accd6bf1d3cdb --- /dev/null +++ b/open_clip/utils.py @@ -0,0 +1,89 @@ +from itertools import repeat +import collections.abc + +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) + +# Replaces all linear layers with linear_replacement +# TODO: add int8 support for other linear layers including attn and convnets +def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, include_modules, copy_weights) + + if isinstance(module, torch.nn.Linear) and name in include_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight.data.copy_(old_module.weight.data) + if model._modules[name].bias is not None: + model._modules[name].bias.data.copy_(old_module.bias) + + return model + +def convert_int8_model_to_inference_mode(model): + for m in model.modules(): + if hasattr(m, 'prepare_for_eval'): + int8_original_dtype = m.weight.dtype + m.prepare_for_eval() + m.int8_original_dtype = int8_original_dtype \ No newline at end of file diff --git a/open_clip/version.py b/open_clip/version.py new file mode 100644 index 0000000000000000000000000000000000000000..a910817da22d06aa0244c6d488b40d30da2bfb7e --- /dev/null +++ b/open_clip/version.py @@ -0,0 +1 @@ +__version__ = '2.20.0' diff --git a/open_clip/zero_shot_classifier.py b/open_clip/zero_shot_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a5267cea4119994e30bb4830a6744cf25bdbaf --- /dev/null +++ b/open_clip/zero_shot_classifier.py @@ -0,0 +1,111 @@ +from functools import partial +from itertools import islice +from typing import Callable, List, Optional, Sequence, Union + +import torch +import torch.nn.functional as F + + +def batched(iterable, n): + """Batch data into lists of length *n*. The last batch may be shorter. + NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl + """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch + + +def build_zero_shot_classifier( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + num_classes_per_batch: Optional[int] = 10, + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names in batches + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + num_classes_per_batch: The number of classes to batch together in each forward, all if None + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + use_format = isinstance(templates[0], str) + num_templates = len(templates) + num_classes = len(classnames) + if use_tqdm: + import tqdm + num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) + iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) + else: + iter_wrap = iter + + def _process_batch(batch_classnames): + num_batch_classes = len(batch_classnames) + texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] + input_ids, attention_mask = tokenizer(texts) + input_ids, attention_mask = input_ids.to(device), attention_mask.to(device) + class_embeddings = F.normalize(model.encode_text(input_ids, attention_mask), dim=-1) + class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) + class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) + class_embeddings = class_embeddings.T + return class_embeddings + + with torch.no_grad(): + if num_classes_per_batch: + batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] + zeroshot_weights = torch.cat(batched_embeds, dim=1) + else: + zeroshot_weights = _process_batch(classnames) + return zeroshot_weights + + +def build_zero_shot_classifier_legacy( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names 1 by 1 + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + if use_tqdm: + import tqdm + iter_wrap = tqdm.tqdm + else: + iter_wrap = iter + + use_format = isinstance(templates[0], str) + + with torch.no_grad(): + zeroshot_weights = [] + for classname in iter_wrap(classnames): + texts = [template.format(classname) if use_format else template(classname) for template in templates] + texts = tokenizer(texts).to(device) # tokenize + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) + + return zeroshot_weights + diff --git a/open_clip/zero_shot_metadata.py b/open_clip/zero_shot_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb452bbb6e27b71cff1dd27e2bb263259b9363f --- /dev/null +++ b/open_clip/zero_shot_metadata.py @@ -0,0 +1,266 @@ + +OPENAI_IMAGENET_TEMPLATES = ( + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +) + + +# a much smaller subset of above prompts +# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +SIMPLE_IMAGENET_TEMPLATES = ( + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +) + + +IMAGENET_CLASSNAMES = ( + "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", + "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", + "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", + "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", + "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", + "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", + "box turtle", "banded gecko", "green iguana", "Carolina anole", + "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", + "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", + "American alligator", "triceratops", "worm snake", "ring-necked snake", + "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", + "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", + "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", + "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", + "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", + "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", + "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", + "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", + "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", + "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", + "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", + "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", + "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", + "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", + "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", + "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", + "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", + "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", + "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", + "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", + "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", + "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", + "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", + "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", + "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", + "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", + "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", + "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", + "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", + "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", + "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", + "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", + "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", + "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", + "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", + "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", + "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", + "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", + "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", + "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", + "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", + "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", + "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", + "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", + "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", + "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", + "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", + "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", + "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", + "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", + "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", + "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", + "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", + "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", + "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", + "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", + "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", + "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", + "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", + "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", + "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", + "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", + "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", + "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", + "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", + "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", + "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", + "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", + "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", + "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", + "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", + "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", + "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", + "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", + "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", + "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", + "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", + "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", + "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", + "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", + "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", + "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", + "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", + "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", + "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", + "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", + "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", + "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", + "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", + "freight car", "French horn", "frying pan", "fur coat", "garbage truck", + "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", + "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", + "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", + "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", + "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", + "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", + "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", + "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", + "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", + "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", + "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", + "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", + "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", + "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", + "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", + "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", + "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", + "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", + "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", + "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", + "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", + "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", + "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", + "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", + "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", + "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", + "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", + "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", + "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", + "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", + "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", + "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", + "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", + "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", + "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", + "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", + "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", + "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", + "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", + "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", + "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", + "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", + "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", + "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", + "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", + "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", + "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", + "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", + "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", + "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", + "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", + "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", + "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", + "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", + "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", + "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", + "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", + "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", + "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", + "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", + "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", + "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", + "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", + "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", + "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper" +) + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..58118fbdc8b3b1b951ff0ce94951e753f5b23071 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +accelerate==0.20.3 +datasets==2.13.0 +decord==0.6.0 +einops==0.6.1 +evaluate==0.4.0 +ftfy==6.1.1 +iopath==0.1.10 +opencv-python==4.7.0.72 +peft @ git+https://github.com/huggingface/peft@08cb3dde577747f6ca6638c884fd66fd16cf2e9d +pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d +scipy==1.10.1 +tensorboardX==2.6.1 +tokenizers==0.13.3 +tqdm==4.65.0 +numpy==1.23.0 +scikit-learn==1.3.0 +braceexpand==0.1.7 +webdataset==0.2.48 +transformers==4.30.2 +urllib3==1.26.15 +SoundFile diff --git a/scripts/audio_language/eval.sh b/scripts/audio_language/eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..dcd28e0f4434bd664aadf696f42117498aba29f9 --- /dev/null +++ b/scripts/audio_language/eval.sh @@ -0,0 +1,25 @@ + +CACHE_DIR="path/to/pretrained/weight" +RESUME="audio_language.pt" +ANNOTATION="path/to/data" +# this script is for 512 total batch_size (n(16) GPUs * batch_size(32) * accum_freq(1)) +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nnodes=2 --nproc_per_node 8 \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 4800000 \ + --clip-type "al" --num_mel_bins 126 --target_length 1036 --audio_sample_rate 16000 --audio_mean -4.2677393 --audio_std 4.5689974 \ + --lock-text --lock-image --text-type "polish_mplug" \ + --init-temp 0.07 --learn-temp \ + --model "ViT-L-14" --cache-dir ${CACHE_DIR} \ + --convert_to_lora --lora_r 16 \ + --lr 1e-3 --coef-lr 1 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 1 --force-patch-dropout 0.1 \ + --epochs 16 --batch-size 16 --accum-freq 4 --warmup 2000 \ + --precision "amp" --workers 10 --video-decode-backend "imgs" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume ${RESUME} \ + --do_eval \ + --val_a_cls_data "ESC50" "VGGSound" "Audioset" \ + --val_al_ret_data "Clotho" "Audiocaps" + diff --git a/scripts/audio_language/train.sh b/scripts/audio_language/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..bbbf1d389c9a6ec65bb0560e897f614f430bf83f --- /dev/null +++ b/scripts/audio_language/train.sh @@ -0,0 +1,23 @@ + +CACHE_DIR="path/to/pretrained/weight" +ANNOTATION="path/to/data" +# this script is for 1024 total batch_size (n(16) GPUs * batch_size(16) * accum_freq(4)) +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nnodes=2 --nproc_per_node 8 \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 4800000 \ + --clip-type "al" --num_mel_bins 126 --target_length 1036 --audio_sample_rate 16000 --audio_mean -4.2677393 --audio_std 4.5689974 \ + --lock-text --lock-image --text-type "polish_mplug" \ + --init-temp 0.07 --learn-temp \ + --model "ViT-L-14" --cache-dir ${CACHE_DIR} \ + --convert_to_lora --lora_r 16 \ + --lr 1e-3 --coef-lr 1 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 1 --force-patch-dropout 0.1 \ + --epochs 16 --batch-size 16 --accum-freq 4 --warmup 2000 \ + --precision "amp" --workers 10 --video-decode-backend "imgs" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume "latest" \ + --do_eval --do_train \ + --val_a_cls_data "ESC50" "VGGSound" "Audioset" \ + --val_al_ret_data "Clotho" "Audiocaps" diff --git a/scripts/depth_language/eval.sh b/scripts/depth_language/eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..d7cc2760cfeea1b55acd834a1642eb0e16e3a8bd --- /dev/null +++ b/scripts/depth_language/eval.sh @@ -0,0 +1,25 @@ + +CACHE_DIR="path/to/pretrained/weight" +RESUME="thermal_language.pt" +ANNOTATION="path/to/data" +# this script is for 1024 total batch_size (n(8) GPUs * batch_size(128) * accum_freq(1)) +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nnodes=$HOST_NUM --node_rank=$INDEX --nproc_per_node $HOST_GPU_NUM --master_addr $CHIEF_IP \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 3020000 \ + --clip-type "dl" --max-depth 10 \ + --lock-text --lock-image --text-type "polish_mplug" \ + --init-temp 0.07 --learn-temp \ + --model "ViT-L-14" --cache-dir ${CACHE_DIR} \ + --convert_to_lora --lora_r 2 \ + --lr 5e-4 --coef-lr 1e-3 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 1 --force-patch-dropout 0.5 \ + --epochs 1 --batch-size 128 --accum-freq 1 --warmup 200 \ + --precision "amp" --workers 10 --video-decode-backend "imgs" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume ${RESUME} \ + --do_eval \ + --val_d_cls_data "NYUV2" + + diff --git a/scripts/depth_language/train.sh b/scripts/depth_language/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..203e6f445cd7b04b2db40c856dfe8ee2beb91fe5 --- /dev/null +++ b/scripts/depth_language/train.sh @@ -0,0 +1,25 @@ + +CACHE_DIR="path/to/pretrained/weight" +ANNOTATION="path/to/data" +# this script is for 1024 total batch_size (n(8) GPUs * batch_size(128) * accum_freq(1)) +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nnodes=$HOST_NUM --node_rank=$INDEX --nproc_per_node $HOST_GPU_NUM --master_addr $CHIEF_IP \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 3020000 \ + --clip-type "dl" --max-depth 10 \ + --do_train \ + --lock-text --lock-image --text-type "polish_mplug" \ + --init-temp 0.07 --learn-temp \ + --model "ViT-L-14" --cache-dir ${CACHE_DIR} \ + --convert_to_lora --lora_r 2 \ + --lr 5e-4 --coef-lr 1e-3 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 1 --force-patch-dropout 0.5 \ + --epochs 1 --batch-size 128 --accum-freq 1 --warmup 200 \ + --precision "amp" --workers 10 --video-decode-backend "imgs" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume "latest" \ + --do_eval \ + --val_d_cls_data "NYUV2" + + diff --git a/scripts/thermal_language/eval.sh b/scripts/thermal_language/eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..c996f947d97fb1730d2404da098296879dd576e7 --- /dev/null +++ b/scripts/thermal_language/eval.sh @@ -0,0 +1,24 @@ + + +CACHE_DIR="path/to/pretrained/weight" +RESUME="thermal_language.pt" +ANNOTATION="path/to/data" +# this script is for 1024 total batch_size (n(8) GPUs * batch_size(128) * accum_freq(1)) +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nnodes=$HOST_NUM --node_rank=$INDEX --nproc_per_node $HOST_GPU_NUM --master_addr $CHIEF_IP \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 3020000 \ + --clip-type "tl" \ + --lock-text --lock-image --text-type "polish_mplug" \ + --init-temp 0.07 --learn-temp \ + --model "ViT-L-14" --cache-dir ${CACHE_DIR} \ + --convert_to_lora --lora_r 2 \ + --lr 1e-4 --coef-lr 1e-3 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 1 --force-patch-dropout 0.5 \ + --epochs 1 --batch-size 128 --accum-freq 1 --warmup 200 \ + --precision "amp" --workers 10 --video-decode-backend "imgs" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume ${RESUME} \ + --do_eval \ + --val_t_cls_data "LLVIP" "FLIRV1" "FLIRV2" diff --git a/scripts/thermal_language/train.sh b/scripts/thermal_language/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..0075fe2066dd5b4176dd29d7a973716d8ea8fbd5 --- /dev/null +++ b/scripts/thermal_language/train.sh @@ -0,0 +1,24 @@ + + +CACHE_DIR="path/to/pretrained/weight" +ANNOTATION="path/to/data" +# this script is for 1024 total batch_size (n(8) GPUs * batch_size(128) * accum_freq(1)) +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nnodes=$HOST_NUM --node_rank=$INDEX --nproc_per_node $HOST_GPU_NUM --master_addr $CHIEF_IP \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 3020000 \ + --clip-type "tl" \ + --do_train \ + --lock-text --lock-image --text-type "polish_mplug" \ + --init-temp 0.07 --learn-temp \ + --model "ViT-L-14" --cache-dir ${CACHE_DIR} \ + --convert_to_lora --lora_r 2 \ + --lr 1e-4 --coef-lr 1e-3 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 1 --force-patch-dropout 0.5 \ + --epochs 1 --batch-size 128 --accum-freq 1 --warmup 200 \ + --precision "amp" --workers 10 --video-decode-backend "imgs" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume "latest" \ + --do_eval \ + --val_t_cls_data "LLVIP" "FLIRV1" "FLIRV2" diff --git a/scripts/video_language/eval.sh b/scripts/video_language/eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..f5470c331a0d15ec86601a214c9f19688b6d9a65 --- /dev/null +++ b/scripts/video_language/eval.sh @@ -0,0 +1,23 @@ + +CACHE_DIR="path/to/pretrained/weight" +RESUME="video_language.pt" +ANNOTATION="path/to/data" +# this script is for 640 total batch_size (n(16) GPUs * batch_size(10) * accum_freq(4)) +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nnodes=$HOST_NUM --node_rank=$INDEX --nproc_per_node $HOST_GPU_NUM --master_addr $CHIEF_IP \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 3020000 \ + --clip-type "vl" --add-time-attn \ + --lock-text --lock-image --text-type "polish_mplug" \ + --init-temp 0.07 --learn-temp \ + --model "ViT-L-14" --cache-dir ${CACHE_DIR} \ + --convert_to_lora --lora_r 16 \ + --lr 1e-4 --coef-lr 1 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 8 --force-patch-dropout 0.3 \ + --epochs 16 --batch-size 10 --accum-freq 4 --warmup 2000 \ + --precision "amp" --workers 10 --video-decode-backend "imgs" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume "latest" \ + --do_eval \ + --val_vl_ret_data "msrvtt" "msvd" "activity" "didemo" diff --git a/scripts/video_language/train.sh b/scripts/video_language/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..39184aa872cafed02d011903c174eed1338f5acc --- /dev/null +++ b/scripts/video_language/train.sh @@ -0,0 +1,23 @@ + +CACHE_DIR="path/to/pretrained/weight" +ANNOTATION="path/to/data" +# this script is for 640 total batch_size (n(16) GPUs * batch_size(10) * accum_freq(4)) +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nnodes=$HOST_NUM --node_rank=$INDEX --nproc_per_node $HOST_GPU_NUM --master_addr $CHIEF_IP \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 3020000 --add-time-attn \ + --clip-type "vl" \ + --do_train \ + --lock-text --lock-image --text-type "polish_mplug" \ + --init-temp 0.07 --learn-temp \ + --model "ViT-L-14" --cache-dir ${CACHE_DIR} \ + --convert_to_lora --lora_r 16 \ + --lr 1e-4 --coef-lr 1 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 8 --force-patch-dropout 0.3 \ + --epochs 16 --batch-size 10 --accum-freq 4 --warmup 2000 \ + --precision "amp" --workers 10 --video-decode-backend "imgs" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume "latest" \ + --do_eval \ + --val_vl_ret_data "msrvtt" "msvd" "activity" "didemo" diff --git a/scripts/video_language/train_1.5_huge.sh b/scripts/video_language/train_1.5_huge.sh new file mode 100644 index 0000000000000000000000000000000000000000..e26546c058b3c556d4de084f8f3905d573f00a4b --- /dev/null +++ b/scripts/video_language/train_1.5_huge.sh @@ -0,0 +1,21 @@ +CACHE_DIR="path/to/pretrained/weight" +ANNOTATION="path/to/data" +# this script is for 1024 total batch_size (n(64) GPUs * batch_size(16) * accum_freq(1)) +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nnodes=$HOST_NUM --node_rank=$INDEX --nproc_per_node $HOST_GPU_NUM --master_addr $CHIEF_IP \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 10076613 \ + --clip-type "vl_new" --add-time-attn \ + --do_train \ + --lock-text --lock-image --text-type "mix" \ + --init-temp 0.07 --learn-temp --grad-checkpointing \ + --model "ViT-H-14" --cache-dir ${CACHE_DIR} \ + --lr 1e-4 --coef-lr 1 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 8 --tube-size 1 --force-patch-dropout 0.3 \ + --epochs 6 --batch-size 16 --accum-freq 1 --warmup 2000 \ + --precision "amp" --workers 10 --video-decode-backend "decord" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume "latest" \ + --do_eval \ + --val_vl_ret_data "msrvtt" "msvd" "activity" "didemo" \ No newline at end of file diff --git a/scripts/video_language/train_1.5_large.sh b/scripts/video_language/train_1.5_large.sh new file mode 100644 index 0000000000000000000000000000000000000000..8b4254034ab9542fb2d085b25f8b05515ed6f8e0 --- /dev/null +++ b/scripts/video_language/train_1.5_large.sh @@ -0,0 +1,21 @@ +CACHE_DIR="path/to/pretrained/weight" +ANNOTATION="path/to/data" +# this script is for 1024 total batch_size (n(64) GPUs * batch_size(16) * accum_freq(1)) +cd /path/to/LanguageBind +TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torchrun --nnodes=$HOST_NUM --node_rank=$INDEX --nproc_per_node $HOST_GPU_NUM --master_addr $CHIEF_IP \ + -m main \ + --train-data ${ANNOTATION} \ + --train-num-samples 10076613 \ + --clip-type "vl_new" --add-time-attn \ + --do_train \ + --lock-text --lock-image --text-type "mix" \ + --init-temp 0.07 --learn-temp \ + --model "ViT-L-14" --cache-dir ${CACHE_DIR} \ + --lr 1e-4 --coef-lr 1 \ + --beta1 0.9 --beta2 0.98 --wd 0.2 --eps 1e-6 \ + --num-frames 8 --tube-size 1 --force-patch-dropout 0.3 \ + --epochs 6 --batch-size 8 --accum-freq 2 --warmup 2000 \ + --precision "amp" --workers 10 --video-decode-backend "decord" \ + --save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume "latest" \ + --do_eval \ + --val_vl_ret_data "msrvtt" "msvd" "activity" "didemo" \ No newline at end of file diff --git a/t_cls/cp_zero_shot_metadata.py b/t_cls/cp_zero_shot_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..87dfd0bd13f49fdfc7729b7d2a537c6b3afb9318 --- /dev/null +++ b/t_cls/cp_zero_shot_metadata.py @@ -0,0 +1,115 @@ +import os + +import pandas as pd + +OPENAI_IMAGENET_TEMPLATES = ( + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +) + +# a much smaller subset of above prompts +# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +SIMPLE_IMAGENET_TEMPLATES = ( + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +) + +CLASSNAMES = { + 'LLVIP': ( + "background", "people" + ), + 'FLIRV1': ( + "bicycle", "car", "dog", "person" + ), + 'FLIRV2': ( + "bike", "bus", "car or pick-up trucks or vans", "hydrant", "traffic light", "motor", "construction equipment or trailers", + "person", "sign", "skateboard", "stroller or pram", "semi truck or freight truck" + ), + 'LSOTB': ( + "airplane", "badger", "bat", "bird", "boat", "bus", "car", "cat", "cow", "coyote", "deer", "dog", + "drone", "fox", "helicopter", "hog", "leopard", "motobike", "person", "truck" + ) +} diff --git a/t_cls/datasets.py b/t_cls/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..eac021b6e0241af775a8cf9a2534c09965d5620e --- /dev/null +++ b/t_cls/datasets.py @@ -0,0 +1,20 @@ +import cv2 +import torch + +from data.build_datasets import DataInfo +from data.process_thermal import get_thermal_transform +from torchvision import datasets + +def get_thermal_dataset(args): + data_path = args.thermal_data_path + transform = get_thermal_transform(args) + dataset = datasets.ImageFolder(data_path, transform=transform) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=None, + ) + + return DataInfo(dataloader=dataloader, sampler=None) diff --git a/t_cls/precision.py b/t_cls/precision.py new file mode 100644 index 0000000000000000000000000000000000000000..a63b92256518d13afd57261df1568e26b1622201 --- /dev/null +++ b/t_cls/precision.py @@ -0,0 +1,12 @@ +import torch +from contextlib import suppress + + +def get_autocast(precision): + if precision == 'amp': + return torch.cuda.amp.autocast + elif precision == 'amp_bfloat16' or precision == 'amp_bf16': + # amp_bfloat16 is more stable than amp float16 for clip training + return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return suppress diff --git a/t_cls/zero_shot.py b/t_cls/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6a6054cde63e3d2b15435636cf8562e856a6a9 --- /dev/null +++ b/t_cls/zero_shot.py @@ -0,0 +1,95 @@ +import logging + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from open_clip import get_input_dtype, get_tokenizer +from open_clip.factory import HF_HUB_PREFIX +from .precision import get_autocast +from .zero_shot_classifier import build_zero_shot_classifier +from .zero_shot_metadata import CLASSNAMES, OPENAI_IMAGENET_TEMPLATES + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] + + +def run(model, classifier, dataloader, args): + autocast = get_autocast(args.precision) + input_dtype = get_input_dtype(args.precision) + + with torch.no_grad(): + top1, top5, n = 0., 0., 0. + for images, target in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(device=args.device, dtype=input_dtype) + images = images.unsqueeze(2) + target = target.to(args.device) + + with autocast(): + # predict + output = model(image=images) + image_features = output['image_features'] if isinstance(output, dict) else output[0] + logits = 100. * image_features @ classifier + + # measure accuracy + if args.val_t_cls_data == 'LLVIP' or args.val_t_cls_data == 'FLIRV1': + # if args.val_t_cls_data == 'LLVIP': + acc1, acc5 = accuracy(logits, target, topk=(1, ))[0], 0 + else: + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = (top1 / n) + top5 = (top5 / n) + return top1, top5 + + +def zero_shot_eval(model, data, epoch, args): + temp_val_t_cls_data = args.val_t_cls_data + args.val_t_cls_data = list(data.keys()) + assert len(args.val_t_cls_data) == 1 + args.val_t_cls_data = args.val_t_cls_data[0] + + if args.val_t_cls_data not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + if args.distributed and not args.horovod: + model = model.module + + logging.info(f'Starting zero-shot {args.val_t_cls_data.upper()}.') + + logging.info('Building zero-shot classifier') + autocast = get_autocast(args.precision) + with autocast(): + tokenizer = get_tokenizer(HF_HUB_PREFIX+args.model, cache_dir=args.cache_dir) + # tokenizer = get_tokenizer("ViT-L-14") + classifier = build_zero_shot_classifier( + model, + tokenizer=tokenizer, + classnames=CLASSNAMES[args.val_t_cls_data], + templates=OPENAI_IMAGENET_TEMPLATES, + num_classes_per_batch=2 if args.val_t_cls_data == 'LLVIP' or args.val_t_cls_data == 'FLIRV1' else 10, ############# + # num_classes_per_batch=2, ############# + device=args.device, + use_tqdm=True, + ) + + logging.info('Using classifier') + results = {} + if args.val_t_cls_data in data: + top1, top5 = run(model, classifier, data[args.val_t_cls_data].dataloader, args) + results[f'{args.val_t_cls_data}-zeroshot-val-top1'] = top1 + if args.val_t_cls_data != 'LLVIP' and args.val_t_cls_data != 'FLIRV1': + results[f'{args.val_t_cls_data}-zeroshot-val-top5'] = top5 + + logging.info(f'Finished zero-shot {args.val_t_cls_data.upper()}.') + args.val_t_cls_data = temp_val_t_cls_data + return results diff --git a/t_cls/zero_shot_classifier.py b/t_cls/zero_shot_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a5267cea4119994e30bb4830a6744cf25bdbaf --- /dev/null +++ b/t_cls/zero_shot_classifier.py @@ -0,0 +1,111 @@ +from functools import partial +from itertools import islice +from typing import Callable, List, Optional, Sequence, Union + +import torch +import torch.nn.functional as F + + +def batched(iterable, n): + """Batch data into lists of length *n*. The last batch may be shorter. + NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl + """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch + + +def build_zero_shot_classifier( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + num_classes_per_batch: Optional[int] = 10, + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names in batches + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + num_classes_per_batch: The number of classes to batch together in each forward, all if None + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + use_format = isinstance(templates[0], str) + num_templates = len(templates) + num_classes = len(classnames) + if use_tqdm: + import tqdm + num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) + iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) + else: + iter_wrap = iter + + def _process_batch(batch_classnames): + num_batch_classes = len(batch_classnames) + texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] + input_ids, attention_mask = tokenizer(texts) + input_ids, attention_mask = input_ids.to(device), attention_mask.to(device) + class_embeddings = F.normalize(model.encode_text(input_ids, attention_mask), dim=-1) + class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) + class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) + class_embeddings = class_embeddings.T + return class_embeddings + + with torch.no_grad(): + if num_classes_per_batch: + batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] + zeroshot_weights = torch.cat(batched_embeds, dim=1) + else: + zeroshot_weights = _process_batch(classnames) + return zeroshot_weights + + +def build_zero_shot_classifier_legacy( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names 1 by 1 + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + if use_tqdm: + import tqdm + iter_wrap = tqdm.tqdm + else: + iter_wrap = iter + + use_format = isinstance(templates[0], str) + + with torch.no_grad(): + zeroshot_weights = [] + for classname in iter_wrap(classnames): + texts = [template.format(classname) if use_format else template(classname) for template in templates] + texts = tokenizer(texts).to(device) # tokenize + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) + + return zeroshot_weights + diff --git a/t_cls/zero_shot_metadata.py b/t_cls/zero_shot_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..105281ac8eb3ed7189c9bb55b7b904157d4cc5a9 --- /dev/null +++ b/t_cls/zero_shot_metadata.py @@ -0,0 +1,232 @@ +# import os +# +# import pandas as pd +# +# OPENAI_IMAGENET_TEMPLATES = ( +# lambda c: f'a bad thermal infrared photo of a {c}.', +# lambda c: f'a thermal infrared photo of many {c}.', +# lambda c: f'a sculpture of a {c}.', +# lambda c: f'a thermal infrared photo of the hard to see {c}.', +# lambda c: f'a low resolution thermal infrared photo of the {c}.', +# lambda c: f'a rendering of a {c}.', +# lambda c: f'graffiti of a {c}.', +# lambda c: f'a bad thermal infrared photo of the {c}.', +# lambda c: f'a cropped thermal infrared photo of the {c}.', +# lambda c: f'a tattoo of a {c}.', +# lambda c: f'the embroidered {c}.', +# lambda c: f'a thermal infrared photo of a hard to see {c}.', +# lambda c: f'a bright thermal infrared photo of a {c}.', +# lambda c: f'a thermal infrared photo of a clean {c}.', +# lambda c: f'a thermal infrared photo of a dirty {c}.', +# lambda c: f'a dark thermal infrared photo of the {c}.', +# lambda c: f'a drawing of a {c}.', +# lambda c: f'a thermal infrared photo of my {c}.', +# lambda c: f'the plastic {c}.', +# lambda c: f'a thermal infrared photo of the cool {c}.', +# lambda c: f'a close-up thermal infrared photo of a {c}.', +# lambda c: f'a black and white thermal infrared photo of the {c}.', +# lambda c: f'a painting of the {c}.', +# lambda c: f'a painting of a {c}.', +# lambda c: f'a pixelated thermal infrared photo of the {c}.', +# lambda c: f'a sculpture of the {c}.', +# lambda c: f'a bright thermal infrared photo of the {c}.', +# lambda c: f'a cropped thermal infrared photo of a {c}.', +# lambda c: f'a plastic {c}.', +# lambda c: f'a thermal infrared photo of the dirty {c}.', +# lambda c: f'a jpeg corrupted thermal infrared photo of a {c}.', +# lambda c: f'a blurry thermal infrared photo of the {c}.', +# lambda c: f'a thermal infrared photo of the {c}.', +# lambda c: f'a good thermal infrared photo of the {c}.', +# lambda c: f'a rendering of the {c}.', +# lambda c: f'a {c} in a video game.', +# lambda c: f'a thermal infrared photo of one {c}.', +# lambda c: f'a doodle of a {c}.', +# lambda c: f'a close-up thermal infrared photo of the {c}.', +# lambda c: f'a thermal infrared photo of a {c}.', +# lambda c: f'the origami {c}.', +# lambda c: f'the {c} in a video game.', +# lambda c: f'a sketch of a {c}.', +# lambda c: f'a doodle of the {c}.', +# lambda c: f'a origami {c}.', +# lambda c: f'a low resolution thermal infrared photo of a {c}.', +# lambda c: f'the toy {c}.', +# lambda c: f'a rendition of the {c}.', +# lambda c: f'a thermal infrared photo of the clean {c}.', +# lambda c: f'a thermal infrared photo of a large {c}.', +# lambda c: f'a rendition of a {c}.', +# lambda c: f'a thermal infrared photo of a nice {c}.', +# lambda c: f'a thermal infrared photo of a weird {c}.', +# lambda c: f'a blurry thermal infrared photo of a {c}.', +# lambda c: f'a cartoon {c}.', +# lambda c: f'art of a {c}.', +# lambda c: f'a sketch of the {c}.', +# lambda c: f'a embroidered {c}.', +# lambda c: f'a pixelated thermal infrared photo of a {c}.', +# lambda c: f'itap of the {c}.', +# lambda c: f'a jpeg corrupted thermal infrared photo of the {c}.', +# lambda c: f'a good thermal infrared photo of a {c}.', +# lambda c: f'a plushie {c}.', +# lambda c: f'a thermal infrared photo of the nice {c}.', +# lambda c: f'a thermal infrared photo of the small {c}.', +# lambda c: f'a thermal infrared photo of the weird {c}.', +# lambda c: f'the cartoon {c}.', +# lambda c: f'art of the {c}.', +# lambda c: f'a drawing of the {c}.', +# lambda c: f'a thermal infrared photo of the large {c}.', +# lambda c: f'a black and white thermal infrared photo of a {c}.', +# lambda c: f'the plushie {c}.', +# lambda c: f'a dark thermal infrared photo of a {c}.', +# lambda c: f'itap of a {c}.', +# lambda c: f'graffiti of the {c}.', +# lambda c: f'a toy {c}.', +# lambda c: f'itap of my {c}.', +# lambda c: f'a thermal infrared photo of a cool {c}.', +# lambda c: f'a thermal infrared photo of a small {c}.', +# lambda c: f'a tattoo of the {c}.', +# ) +# +# # a much smaller subset of above prompts +# # from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +# SIMPLE_IMAGENET_TEMPLATES = ( +# lambda c: f'itap of a {c}.', +# lambda c: f'a bad thermal infrared photo of the {c}.', +# lambda c: f'a origami {c}.', +# lambda c: f'a thermal infrared photo of the large {c}.', +# lambda c: f'a {c} in a video game.', +# lambda c: f'art of the {c}.', +# lambda c: f'a thermal infrared photo of the small {c}.', +# ) +# +# CLASSNAMES = { +# 'LLVIP': ( +# "background", "people" +# ), +# 'FLIRV1': ( +# "bicycle", "car", "dog", "person" +# ), +# 'FLIRV2': ( +# "bike", "bus", "car or pick-up trucks or vans", "hydrant", "traffic light", "motor", "construction equipment or trailers", +# "person", "sign", "skateboard", "stroller or pram", "semi truck or freight truck" +# ), +# 'LSOTB': ( +# "airplane", "badger", "bat", "bird", "boat", "bus", "car", "cat", "cow", "coyote", "deer", "dog", +# "drone", "fox", "helicopter", "hog", "leopard", "motobike", "person", "truck" +# ) +# } + + +import os + +import pandas as pd + +OPENAI_IMAGENET_TEMPLATES = ( + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +) + +# a much smaller subset of above prompts +# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +SIMPLE_IMAGENET_TEMPLATES = ( + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +) + +CLASSNAMES = { + 'LLVIP': ( + "background", "people" + ), + 'FLIRV1': ( + "bicycle", "car", "dog", "person" + ), + 'FLIRV2': ( + "bike", "bus", "car or pick-up trucks or vans", "hydrant", "traffic light", "motor", "construction equipment or trailers", + "person", "sign", "skateboard", "stroller or pram", "semi truck or freight truck" + ), + 'LSOTB': ( + "airplane", "badger", "bat", "bird", "boat", "bus", "car", "cat", "cow", "coyote", "deer", "dog", + "drone", "fox", "helicopter", "hog", "leopard", "motobike", "person", "truck" + ) +} diff --git a/t_cls/zeroshot_cls.py b/t_cls/zeroshot_cls.py new file mode 100644 index 0000000000000000000000000000000000000000..aac2222a79c47293bf409aa38f1d88f10ed4b024 --- /dev/null +++ b/t_cls/zeroshot_cls.py @@ -0,0 +1,47 @@ + +import json +import logging +import os +from training.distributed import is_master +from .zero_shot import zero_shot_eval + +try: + import wandb +except ImportError: + wandb = None + + + +def evaluate_t_cls(model, data, epoch, args, tb_writer=None): + metrics = {} + if not is_master(args): + return metrics + model.eval() + + zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + metrics.update(zero_shot_metrics) + + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) + ) + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/t_cls/{args.val_t_cls_data[0].lower()}/{name}", val, epoch) + args.t_cls_output_dir = os.path.join(args.log_base_path, f't_cls/{args.val_t_cls_data[0].lower()}') + os.makedirs(args.t_cls_output_dir, exist_ok=True) + with open(os.path.join(args.t_cls_output_dir, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, 'Please install wandb.' + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, 'epoch': epoch}) + + return metrics diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..ae9ae4008e20088a16d021067aafe1a6ddee1d62 --- /dev/null +++ b/train.py @@ -0,0 +1,256 @@ +import json +import logging +import math +import os +import time + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.parallel.distributed import DistributedDataParallel + +from training.distributed import is_master +from training.precision import get_autocast + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import get_input_dtype, CLIP, CustomTextCLIP + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +def postprocess_clip_output(model_out): + return { + "image_features": model_out[0], + "text_features": model_out[1], + "logit_scale": model_out[2] + } + +def unwrap_model(model): + if hasattr(model, 'module'): + return model.module + else: + return model + + +def backward(total_loss, scaler): + if scaler is not None: + scaler.scale(total_loss).backward() + else: + total_loss.backward() + + +def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None): + device = torch.device(args.device) + autocast = get_autocast(args.precision) + input_dtype = get_input_dtype(args.precision) + + + model.train() + if args.distill: + dist_model.eval() + + data[f'{args.clip_type}_pt'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch + dataloader = data[f'{args.clip_type}_pt'].dataloader + num_batches_per_epoch = dataloader.num_batches // args.accum_freq + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + if args.accum_freq > 1: + accum_images, accum_input_ids, accum_attention_mask, accum_features = [], [], [], {} + + losses_m = {} + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + for i, batch in enumerate(dataloader): + i_accum = i // args.accum_freq + step = num_batches_per_epoch * epoch + i_accum + + if not args.skip_scheduler: + scheduler(step) + + images, input_ids, attention_mask = batch + images = images.to(device=device, dtype=input_dtype, non_blocking=True) + input_ids = input_ids.to(device=device, non_blocking=True) + attention_mask = attention_mask.to(device=device, non_blocking=True) + + data_time_m.update(time.time() - end) + optimizer.zero_grad() + + if args.accum_freq == 1: + with autocast(): + model_out = model(images, input_ids, attention_mask) + logit_scale = model_out["logit_scale"] + if args.distill: + with torch.no_grad(): + dist_model_out = dist_model(images, input_ids, attention_mask) + model_out.update({f'dist_{k}' : v for k, v in dist_model_out.items()}) + losses = loss(**model_out, output_dict=True) + + total_loss = sum(losses.values()) + losses["loss"] = total_loss + + backward(total_loss, scaler) + else: + # First, cache the features without any gradient tracking. + with torch.no_grad(): + with autocast(): + model_out = model(images, input_ids, attention_mask) + model_out.pop("logit_scale") + for key, val in model_out.items(): + if key in accum_features: + accum_features[key].append(val) + else: + accum_features[key] = [val] + + accum_images.append(images) + accum_input_ids.append(input_ids) + accum_attention_mask.append(attention_mask) + + # If (i + 1) % accum_freq is not zero, move on to the next batch. + if ((i + 1) % args.accum_freq) > 0: + # FIXME this makes data time logging unreliable when accumulating + continue + + # Now, ready to take gradients for the last accum_freq batches. + # Re-do the forward pass for those batches, and use the cached features from the other batches as negatives. + # Call backwards each time, but only step optimizer at the end. + optimizer.zero_grad() + for j in range(args.accum_freq): + images = accum_images[j] + input_ids = accum_input_ids[j] + attention_mask = accum_attention_mask[j] + with autocast(): + model_out = model(images, input_ids, attention_mask) + logit_scale = model_out.pop("logit_scale") + inputs = {} + for key, val in accum_features.items(): + accumulated = accum_features[key] + inputs[key] = torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:]) + losses = loss(**inputs, logit_scale=logit_scale, output_dict=True) + del inputs + total_loss = sum(losses.values()) + losses["loss"] = total_loss + backward(total_loss, scaler) + + if scaler is not None: + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + if args.grad_clip_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + if args.grad_clip_norm is not None: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + scaler.step(optimizer) + scaler.update() + else: + if args.grad_clip_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + optimizer.step() + + # reset gradient accum, if enabled + if args.accum_freq > 1: + accum_images, accum_input_ids, accum_attention_mask, accum_features = [], [], [], {} + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. + with torch.no_grad(): + unwrap_model(model).logit_scale.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i_accum + 1 + if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch): + batch_size = len(images) + num_samples = batch_count * batch_size * args.accum_freq * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + for key, val in losses.items(): + if key not in losses_m: + losses_m[key] = AverageMeter() + losses_m[key].update(val.item(), batch_size) + + logit_scale_scalar = logit_scale.item() + # if args.add_time_attn: + # if hasattr(model, 'module'): + # t_gate = [[F.sigmoid(m.t_attn_gate).detach().item(), F.sigmoid(m.t_ffn_gate).detach().item()] for m in model.module.vision_model.encoder.layers] + # else: + # t_gate = [[F.sigmoid(m.t_attn_gate).detach().item(), F.sigmoid(m.t_ffn_gate).detach().item()] for m in model.vision_model.encoder.layers] + # t_attn_gate, t_ffn_gate = list(zip(*t_gate)) + loss_log = " ".join( + [ + f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})" + for loss_name, loss_m in losses_m.items() + ] + ) + samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val + samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val + # if args.add_time_attn: + # logging.info( + # f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + # f"Data (t): {data_time_m.avg:.3f} " + # f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu " + # f"LR: {optimizer.param_groups[0]['lr']:5f} " + # f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log + + # f"\nt_attn_gate: {[round(i, 2) for i in t_attn_gate]}\nt_ffn_gate: {[round(i, 2) for i in t_ffn_gate]}\n" + # ) + # else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log + ) + + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "samples_per_second": samples_per_second, + "samples_per_second_per_gpu": samples_per_second_per_gpu, + "scale": logit_scale_scalar, + "lr": optimizer.param_groups[0]["lr"] + } + log_data.update({name:val.val for name,val in losses_m.items()}) + # if args.add_time_attn: + # log_data.update({f'layer_{i}_t_attn_gate': attn for i, attn in enumerate(t_attn_gate)}) + # log_data.update({f'layer_{i}_t_ffn_gate': ffn for i, ffn in enumerate(t_ffn_gate)}) + + for name, val in log_data.items(): + name = "train/" + name + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, 'Please install wandb.' + wandb.log({name: val, 'step': step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for diff --git a/training/.gitignore b/training/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..333c1e910a3e2bef1b9d0d4587392627d8388974 --- /dev/null +++ b/training/.gitignore @@ -0,0 +1 @@ +logs/ diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/training/__pycache__/__init__.cpython-38.pyc b/training/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70c0640b16c7ae5ff176787ed70a66afd3eda7b4 Binary files /dev/null and b/training/__pycache__/__init__.cpython-38.pyc differ diff --git a/training/__pycache__/distributed.cpython-38.pyc b/training/__pycache__/distributed.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfa1748671b4d286b8681cc9036d22732d8018c8 Binary files /dev/null and b/training/__pycache__/distributed.cpython-38.pyc differ diff --git a/training/data.py b/training/data.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed076d96a34a641d3841e2a43221e8ba7e6900f --- /dev/null +++ b/training/data.py @@ -0,0 +1,563 @@ +import ast +import json +import logging +import math +import os +import random +import sys +import braceexpand +from dataclasses import dataclass +from multiprocessing import Value + +import numpy as np +import pandas as pd +import torch +import torchvision.datasets as datasets +import webdataset as wds +from PIL import Image +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info +from torch.utils.data.distributed import DistributedSampler +from webdataset.filters import _shuffle +from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +class CsvDataset(Dataset): + def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None): + logging.debug(f'Loading csv data from {input_filename}.') + df = pd.read_csv(input_filename, sep=sep) + + self.images = df[img_key].tolist() + self.captions = df[caption_key].tolist() + self.transforms = transforms + logging.debug('Done loading data.') + + self.tokenize = tokenizer + + def __len__(self): + return len(self.captions) + + def __getitem__(self, idx): + images = self.transforms(Image.open(str(self.images[idx]))) + texts = self.tokenize([str(self.captions[idx])])[0] + return images, texts + + +class SharedEpoch: + def __init__(self, epoch: int = 0): + self.shared_epoch = Value('i', epoch) + + def set_value(self, epoch): + self.shared_epoch.value = epoch + + def get_value(self): + return self.shared_epoch.value + + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler = None + shared_epoch: SharedEpoch = None + + def set_epoch(self, epoch): + if self.shared_epoch is not None: + self.shared_epoch.set_value(epoch) + if self.sampler is not None and isinstance(self.sampler, DistributedSampler): + self.sampler.set_epoch(epoch) + + +def expand_urls(urls, weights=None): + if weights is None: + expanded_urls = wds.shardlists.expand_urls(urls) + return expanded_urls, None + if isinstance(urls, str): + urllist = urls.split("::") + weights = weights.split('::') + assert len(weights) == len(urllist),\ + f"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match." + weights = [float(weight) for weight in weights] + all_urls, all_weights = [], [] + for url, weight in zip(urllist, weights): + expanded_url = list(braceexpand.braceexpand(url)) + expanded_weights = [weight for _ in expanded_url] + all_urls.extend(expanded_url) + all_weights.extend(expanded_weights) + return all_urls, all_weights + else: + all_urls = list(urls) + return all_urls, weights + + +def get_dataset_size(shards): + shards_list, _ = expand_urls(shards) + dir_path = os.path.dirname(shards_list[0]) + sizes_filename = os.path.join(dir_path, 'sizes.json') + len_filename = os.path.join(dir_path, '__len__') + if os.path.exists(sizes_filename): + sizes = json.load(open(sizes_filename, 'r')) + total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list]) + elif os.path.exists(len_filename): + # FIXME this used to be eval(open(...)) but that seemed rather unsafe + total_size = ast.literal_eval(open(len_filename, 'r').read()) + else: + total_size = None # num samples undefined + # some common dataset sizes (at time of authors last download) + # CC3M (train): 2905954 + # CC12M: 10968539 + # LAION-400M: 407332084 + # LAION-2B (english): 2170337258 + num_shards = len(shards_list) + return total_size, num_shards + + +def get_imagenet(args, preprocess_fns, split): + assert split in ["train", "val", "v2"] + is_train = split == "train" + preprocess_train, preprocess_val = preprocess_fns + + if split == "v2": + from imagenetv2_pytorch import ImageNetV2Dataset + dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) + else: + if is_train: + data_path = args.imagenet_train + preprocess_fn = preprocess_train + else: + data_path = args.imagenet_val + preprocess_fn = preprocess_val + assert data_path + + dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) + + if is_train: + idxs = np.zeros(len(dataset.targets)) + target_array = np.array(dataset.targets) + k = 50 + for c in range(1000): + m = target_array == c + n = len(idxs[m]) + arr = np.zeros(n) + arr[:k] = 1 + np.random.shuffle(arr) + idxs[m] = arr + + idxs = idxs.astype('int') + sampler = SubsetRandomSampler(np.where(idxs)[0]) + else: + sampler = None + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=sampler, + ) + + return DataInfo(dataloader=dataloader, sampler=sampler) + + +def count_samples(dataloader): + os.environ["WDS_EPOCH"] = "0" + n_elements, n_batches = 0, 0 + for images, texts in dataloader: + n_batches += 1 + n_elements += len(images) + assert len(images) == len(texts) + return n_elements, n_batches + + +def filter_no_caption_or_no_image(sample): + has_caption = ('txt' in sample) + has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample) + return has_caption and has_image + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, issue a warning, and continue.""" + logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') + return True + + +def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + # FIXME webdataset version throws if suffix in current_sample, but we have a potential for + # this happening in the current LAION400m dataset if a tar ends with same prefix as the next + # begins, rare, but can happen since prefix aren't unique across tar files in that dataset + if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: + if valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_to_samples_nothrow(src, handler=log_and_continue): + # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw + streams = url_opener(src, handler=handler) + files = tar_file_expander(streams, handler=handler) + samples = group_by_keys_nothrow(files, handler=handler) + return samples + + +def pytorch_worker_seed(increment=0): + """get dataloader worker seed from pytorch""" + worker_info = get_worker_info() + if worker_info is not None: + # favour using the seed already created for pytorch dataloader workers if it exists + seed = worker_info.seed + if increment: + # space out seed increments so they can't overlap across workers in different iterations + seed += increment * max(1, worker_info.num_workers) + return seed + # fallback to wds rank based seed + return wds.utils.pytorch_worker_seed() + + +_SHARD_SHUFFLE_SIZE = 2000 +_SHARD_SHUFFLE_INITIAL = 500 +_SAMPLE_SHUFFLE_SIZE = 5000 +_SAMPLE_SHUFFLE_INITIAL = 1000 + + +class detshuffle2(wds.PipelineStage): + def __init__( + self, + bufsize=1000, + initial=100, + seed=0, + epoch=-1, + ): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + rng = random.Random() + if self.seed < 0: + # If seed is negative, we use the worker's seed, this will be different across all nodes/workers + seed = pytorch_worker_seed(epoch) + else: + # This seed to be deterministic AND the same across all nodes/workers in each epoch + seed = self.seed + epoch + rng.seed(seed) + return _shuffle(src, self.bufsize, self.initial, rng) + + +class ResampledShards2(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__( + self, + urls, + weights=None, + nshards=sys.maxsize, + worker_seed=None, + deterministic=False, + epoch=-1, + ): + """Sample shards from the shard list with replacement. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls, weights = expand_urls(urls, weights) + self.urls = urls + self.weights = weights + if self.weights is not None: + assert len(self.urls) == len(self.weights),\ + f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match." + assert isinstance(self.urls[0], str) + self.nshards = nshards + self.rng = random.Random() + self.worker_seed = worker_seed + self.deterministic = deterministic + self.epoch = epoch + + def __iter__(self): + """Return an iterator over the shards.""" + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + if self.deterministic: + # reset seed w/ epoch if deterministic + if self.worker_seed is None: + # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id + seed = pytorch_worker_seed(epoch) + else: + seed = self.worker_seed() + epoch + self.rng.seed(seed) + for _ in range(self.nshards): + if self.weights is None: + yield dict(url=self.rng.choice(self.urls)) + else: + yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0]) + + +def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None): + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + resampled = getattr(args, 'dataset_resampled', False) and is_train + + num_shards = None + if is_train: + if args.train_num_samples is not None: + num_samples = args.train_num_samples + else: + num_samples, num_shards = get_dataset_size(input_shards) + if not num_samples: + raise RuntimeError( + 'Currently, the number of dataset samples must be specified for the training dataset. ' + 'Please specify it via `--train-num-samples` if no dataset length info is present.') + else: + # Eval will just exhaust the iterator if the size is not specified. + num_samples = args.val_num_samples or 0 + + shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + + if resampled: + pipeline = [ResampledShards2( + input_shards, + weights=args.train_data_upsampling_factors, + deterministic=True, + epoch=shared_epoch, + )] + else: + assert args.train_data_upsampling_factors is None,\ + "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)." + pipeline = [wds.SimpleShardList(input_shards)] + + # at this point we have an iterator over all the shards + if is_train: + if not resampled: + pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ]) + pipeline.extend([ + # at this point, we have an iterator over the shards assigned to each worker at each node + tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ]) + else: + pipeline.extend([ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + pipeline.extend([ + wds.select(filter_no_caption_or_no_image), + wds.decode("pilrgb", handler=log_and_continue), + wds.rename(image="jpg;png;jpeg;webp", text="txt"), + wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), + wds.to_tuple("image", "text"), + wds.batched(args.batch_size, partial=not is_train) + ]) + + dataset = wds.DataPipeline(*pipeline) + + if is_train: + if not resampled: + num_shards = num_shards or len(expand_urls(input_shards)[0]) + assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=args.workers > 0, + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + + +def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None): + input_filename = args.train_data if is_train else args.val_data + assert input_filename + dataset = CsvDataset( + input_filename, + preprocess_fn, + img_key=args.csv_img_key, + caption_key=args.csv_caption_key, + sep=args.csv_separator, + tokenizer=tokenizer + ) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed and is_train else None + shuffle = is_train and sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +class SyntheticDataset(Dataset): + + def __init__( + self, + transform=None, + image_size=(224, 224), + caption="Dummy caption", + dataset_size=100, + tokenizer=None, + ): + self.transform = transform + self.image_size = image_size + self.caption = caption + self.image = Image.new('RGB', image_size) + self.dataset_size = dataset_size + + self.preprocess_txt = lambda text: tokenizer(text)[0] + + def __len__(self): + return self.dataset_size + + def __getitem__(self, idx): + if self.transform is not None: + image = self.transform(self.image) + return image, self.preprocess_txt(self.caption) + + +def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None): + image_size = preprocess_fn.transforms[0].size + dataset = SyntheticDataset( + transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed and is_train else None + shuffle = is_train and sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_dataset_fn(data_path, dataset_type): + if dataset_type == "webdataset": + return get_wds_dataset + elif dataset_type == "csv": + return get_csv_dataset + elif dataset_type == "synthetic": + return get_synthetic_dataset + elif dataset_type == "auto": + ext = data_path.split('.')[-1] + if ext in ['csv', 'tsv']: + return get_csv_dataset + elif ext in ['tar']: + return get_wds_dataset + else: + raise ValueError( + f"Tried to figure out dataset type, but failed for extension {ext}.") + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}") + + +def get_data(args, preprocess_fns, epoch=0, tokenizer=None): + preprocess_train, preprocess_val = preprocess_fns + data = {} + + if args.train_data or args.dataset_type == "synthetic": + data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( + args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) + + if args.val_data: + data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( + args, preprocess_val, is_train=False, tokenizer=tokenizer) + + if args.imagenet_val is not None: + data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val") + + if args.imagenet_v2 is not None: + data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2") + + return data diff --git a/training/distributed.py b/training/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..6d158c210b37c6bd433edeab2b7f2cdb8da3f5af --- /dev/null +++ b/training/distributed.py @@ -0,0 +1,140 @@ +import datetime +import os + +import torch +import torch.distributed as dist + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def is_global_master(args): + return args.rank == 0 + + +def is_local_master(args): + return args.local_rank == 0 + + +def is_master(args, local=False): + return is_local_master(args) if local else is_global_master(args) + + +def is_using_horovod(): + # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set + # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... + ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] + pmi_vars = ["PMI_RANK", "PMI_SIZE"] + if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): + return True + else: + return False + + +def is_using_distributed(): + if 'WORLD_SIZE' in os.environ: + return int(os.environ['WORLD_SIZE']) > 1 + if 'SLURM_NTASKS' in os.environ: + return int(os.environ['SLURM_NTASKS']) > 1 + return False + + +def world_info_from_env(): + local_rank = 0 + for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): + if v in os.environ: + local_rank = int(os.environ[v]) + break + global_rank = 0 + for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): + if v in os.environ: + global_rank = int(os.environ[v]) + break + world_size = 1 + for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): + if v in os.environ: + world_size = int(os.environ[v]) + break + + return local_rank, global_rank, world_size + + +def init_distributed_device(args): + # Distributed training = training on more than one GPU. + # Works in both single and multi-node scenarios. + args.distributed = False + args.world_size = 1 + args.rank = 0 # global rank + args.local_rank = 0 + if args.horovod: + assert hvd is not None, "Horovod is not installed" + hvd.init() + args.local_rank = int(hvd.local_rank()) + args.rank = hvd.rank() + args.world_size = hvd.size() + args.distributed = True + os.environ['LOCAL_RANK'] = str(args.local_rank) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + elif is_using_distributed(): + if 'SLURM_PROCID' in os.environ: + # DDP via SLURM + args.local_rank, args.rank, args.world_size = world_info_from_env() + # SLURM var -> torch.distributed vars in case needed + os.environ['LOCAL_RANK'] = str(args.local_rank) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + torch.distributed.init_process_group( + backend='nccl', + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + timeout=datetime.timedelta(seconds=180) + ) + else: + # DDP via torchrun, torch.distributed.launch + args.local_rank, _, _ = world_info_from_env() + torch.distributed.init_process_group( + backend='nccl', + init_method=args.dist_url, + timeout=datetime.timedelta(seconds=180)) + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + args.distributed = True + + if torch.cuda.is_available(): + if args.distributed and not args.no_set_device_rank: + device = 'cuda:%d' % args.local_rank + else: + device = 'cuda:0' + torch.cuda.set_device(device) + else: + device = 'cpu' + args.device = device + device = torch.device(device) + return device + + +def broadcast_object(args, obj, src=0): + # broadcast a pickle-able python object from rank-0 to all ranks + if args.horovod: + return hvd.broadcast_object(obj, root_rank=src) + else: + if args.rank == src: + objects = [obj] + else: + objects = [None] + dist.broadcast_object_list(objects, src=src) + return objects[0] + + +def all_gather_object(args, obj, dst=0): + # gather a pickle-able python object across all ranks + if args.horovod: + return hvd.allgather_object(obj) + else: + objects = [None for _ in range(args.world_size)] + dist.all_gather_object(objects, obj) + return objects diff --git a/training/file_utils.py b/training/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..395cf7df0acc164c6851f17834d793f5852d4605 --- /dev/null +++ b/training/file_utils.py @@ -0,0 +1,83 @@ +import logging +import os +import multiprocessing +import subprocess +import time +import fsspec +import torch +from tqdm import tqdm + +def remote_sync_s3(local_dir, remote_dir): + # skip epoch_latest which can change during sync. + result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if result.returncode != 0: + logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") + return False + + logging.info(f"Successfully synced with S3 bucket") + return True + +def remote_sync_fsspec(local_dir, remote_dir): + # FIXME currently this is slow and not recommended. Look into speeding up. + a = fsspec.get_mapper(local_dir) + b = fsspec.get_mapper(remote_dir) + + for k in a: + # skip epoch_latest which can change during sync. + if 'epoch_latest.pt' in k: + continue + + logging.info(f'Attempting to sync {k}') + if k in b and len(a[k]) == len(b[k]): + logging.debug(f'Skipping remote sync for {k}.') + continue + + try: + logging.info(f'Successful sync for {k}.') + b[k] = a[k] + except Exception as e: + logging.info(f'Error during remote sync for {k}: {e}') + return False + + return True + +def remote_sync(local_dir, remote_dir, protocol): + logging.info('Starting remote sync.') + if protocol == 's3': + return remote_sync_s3(local_dir, remote_dir) + elif protocol == 'fsspec': + return remote_sync_fsspec(local_dir, remote_dir) + else: + logging.error('Remote protocol not known') + return False + +def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): + while True: + time.sleep(sync_every) + remote_sync(local_dir, remote_dir, protocol) + +def start_sync_process(sync_every, local_dir, remote_dir, protocol): + p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) + return p + +# Note: we are not currently using this save function. +def pt_save(pt_obj, file_path): + of = fsspec.open(file_path, "wb") + with of as f: + torch.save(pt_obj, file_path) + +def pt_load(file_path, map_location=None): + if file_path.startswith('s3'): + logging.info('Loading remote checkpoint, which may take a bit.') + of = fsspec.open(file_path, "rb") + with of as f: + out = torch.load(f, map_location=map_location) + return out + +def check_exists(file_path): + try: + with fsspec.open(file_path): + pass + except FileNotFoundError: + return False + return True diff --git a/training/logger.py b/training/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9abed92568d459cbc8d6094ae3901935d89621 --- /dev/null +++ b/training/logger.py @@ -0,0 +1,26 @@ +import logging + + +def setup_logging(log_file, level, include_host=False): + if include_host: + import socket + hostname = socket.gethostname() + formatter = logging.Formatter( + f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') + else: + formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') + + logging.root.setLevel(level) + loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] + for logger in loggers: + logger.setLevel(level) + + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logging.root.addHandler(stream_handler) + + if log_file: + file_handler = logging.FileHandler(filename=log_file) + file_handler.setFormatter(formatter) + logging.root.addHandler(file_handler) + diff --git a/training/main.py b/training/main.py new file mode 100644 index 0000000000000000000000000000000000000000..2e90ce67d4004bfc38d59238db67bdc8ce33600d --- /dev/null +++ b/training/main.py @@ -0,0 +1,490 @@ +import glob +import logging +import os +import re +import subprocess +import sys +import random +from datetime import datetime + +import numpy as np +import torch +from torch import optim +from torch.cuda.amp import GradScaler + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss +from training.data import get_data +from training.distributed import is_master, init_distributed_device, broadcast_object +from training.logger import setup_logging +from training.params import parse_args +from training.scheduler import cosine_lr, const_lr, const_lr_cooldown +from training.train import train_one_epoch, evaluate +from training.file_utils import pt_load, check_exists, start_sync_process, remote_sync + + +LATEST_CHECKPOINT_NAME = "epoch_latest.pt" + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def natural_key(string_): + """See http://www.codinghorror.com/blog/archives/001018.html""" + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def get_latest_checkpoint(path: str, remote : bool): + # as writen, this glob recurses, so can pick up checkpoints across multiple sub-folders + if remote: + result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + print(result) + if result.returncode == 1: + return None + checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]] + else: + checkpoints = glob.glob(path + '**/*.pt', recursive=True) + if checkpoints: + checkpoints = sorted(checkpoints, key=natural_key) + return checkpoints[-1] + return None + + +def main(args): + args = parse_args(args) + + if torch.cuda.is_available(): + # This enables tf32 on Ampere GPUs which is only 8% slower than + # float16 and almost as accurate as float32 + # This was a default in pytorch until 1.12 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + # fully initialize distributed device environment + device = init_distributed_device(args) + + # get the name of the experiments + if args.name is None: + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + model_name_safe = args.model.replace('/', '-') + date_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + if args.distributed: + # sync date_str from master to all ranks + date_str = broadcast_object(args, date_str) + args.name = '-'.join([ + date_str, + f"model_{model_name_safe}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ]) + + resume_latest = args.resume == 'latest' + log_base_path = os.path.join(args.logs, args.name) + args.log_path = None + if is_master(args, local=args.log_local): + os.makedirs(log_base_path, exist_ok=True) + log_filename = f'out-{args.rank}' if args.log_local else 'out.log' + args.log_path = os.path.join(log_base_path, log_filename) + if os.path.exists(args.log_path) and not resume_latest: + print( + "Error. Experiment already exists. Use --name {} to specify a new experiment." + ) + return -1 + + # Setup text logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # Setup wandb, tensorboard, checkpoint logging + args.wandb = 'wandb' in args.report_to or 'all' in args.report_to + args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to + args.checkpoint_path = os.path.join(log_base_path, "checkpoints") + if is_master(args): + args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else '' + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = '' + + if resume_latest: + resume_from = None + checkpoint_path = args.checkpoint_path + # If using remote_sync, need to check the remote instead of the local checkpoints folder. + if args.remote_sync is not None: + checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints") + if args.save_most_recent: + print('Error. Cannot use save-most-recent with remote_sync and resume latest.') + return -1 + if args.remote_sync_protocol != 's3': + print('Error. Sync protocol not supported when using resume latest.') + return -1 + if is_master(args): + # Checking for existing checkpoint via master rank only. It is possible for + # different rank processes to see different files if a shared file-system is under + # stress, however it's very difficult to fully work around such situations. + if args.save_most_recent: + # if --save-most-recent flag is set, look for latest at a fixed filename + resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME) + if not os.path.exists(resume_from): + # If no latest checkpoint has been saved yet, don't try to resume + resume_from = None + else: + # otherwise, list checkpoint dir contents and pick the newest checkpoint + resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None) + if resume_from: + logging.info(f'Found latest resume checkpoint at {resume_from}.') + else: + logging.info(f'No latest resume checkpoint found in {checkpoint_path}.') + if args.distributed: + # sync found checkpoint path to all ranks + resume_from = broadcast_object(args, resume_from) + args.resume = resume_from + + if args.copy_codebase: + copy_codebase(args) + + # start the sync proces if remote-sync is not None + remote_sync_process = None + if is_master(args) and args.remote_sync is not None: + # first make sure it works + result = remote_sync( + os.path.join(args.logs, args.name), + os.path.join(args.remote_sync, args.name), + args.remote_sync_protocol + ) + if result: + logging.info('remote sync successful.') + else: + logging.info('Error: remote sync failed. Exiting.') + return -1 + # if all looks good, start a process to do this every args.remote_sync_frequency seconds + remote_sync_process = start_sync_process( + args.remote_sync_frequency, + os.path.join(args.logs, args.name), + os.path.join(args.remote_sync, args.name), + args.remote_sync_protocol + ) + remote_sync_process.start() + + if args.precision == 'fp16': + logging.warning( + 'It is recommended to use AMP mixed-precision instead of FP16. ' + 'FP16 support needs further verification and tuning, especially for train.') + + if args.horovod: + logging.info( + f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.' + f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') + elif args.distributed: + logging.info( + f'Running in distributed mode with multiple processes. Device: {args.device}.' + f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') + else: + logging.info(f'Running with a single process. Device {args.device}.') + + dist_model = None + args.distill = args.distill_model is not None and args.distill_pretrained is not None + if args.distill: + #FIXME: support distillation with grad accum. + assert args.accum_freq == 1 + #FIXME: support distillation with coca. + assert 'coca' not in args.model.lower() + + if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1: + # arg is nargs, single (square) image size list -> int + args.force_image_size = args.force_image_size[0] + random_seed(args.seed, 0) + model, preprocess_train, preprocess_val = create_model_and_transforms( + args.model, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + force_custom_text=args.force_custom_text, + force_patch_dropout=args.force_patch_dropout, + force_image_size=args.force_image_size, + pretrained_image=args.pretrained_image, + image_mean=args.image_mean, + image_std=args.image_std, + aug_cfg=args.aug_cfg, + output_dict=True, + ) + if args.distill: + # FIXME: currenlty assumes the model your distilling from has the same tokenizer & transforms. + dist_model, _, _ = create_model_and_transforms( + args.distill_model, + args.distill_pretrained, + device=device, + precision=args.precision, + output_dict=True, + ) + if args.use_bnb_linear is not None: + print('=> using a layer from bitsandbytes.\n' + ' this is an experimental feature which requires two extra pip installs\n' + ' pip install bitsandbytes triton' + ' please make sure to use triton 2.0.0') + import bitsandbytes as bnb + from open_clip.utils import replace_linear + print(f'=> replacing linear layers with {args.use_bnb_linear}') + linear_replacement_cls = getattr(bnb.nn.triton_based_modules, args.use_bnb_linear) + replace_linear(model, linear_replacement_cls) + model = model.to(device) + + random_seed(args.seed, args.rank) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if args.lock_image: + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + model.lock_image_tower( + unlocked_groups=args.lock_image_unlocked_groups, + freeze_bn_stats=args.lock_image_freeze_bn_stats) + if args.lock_text: + model.lock_text_tower( + unlocked_layers=args.lock_text_unlocked_layers, + freeze_layer_norm=args.lock_text_freeze_layer_norm) + + if args.grad_checkpointing: + model.set_grad_checkpointing() + + if is_master(args): + logging.info("Model:") + # logging.info(f"{str(model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args['static_graph'] = True + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) + + if args.distill: + dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args) + + # create optimizer and scaler + optimizer = None + scaler = None + + if args.train_data or args.dataset_type == "synthetic": + assert not args.trace, 'Cannot train with traced model' + + exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n + include = lambda n, p: not exclude(n, p) + + named_parameters = list(model.named_parameters()) + gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] + rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] + + optimizer = optim.AdamW( + [ + {"params": gain_or_bias_params, "weight_decay": 0.}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + ) + if args.horovod: + optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters()) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + checkpoint = pt_load(args.resume, map_location='cpu') + if 'epoch' in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith('module'): + sd = {k[len('module.'):]: v for k, v in sd.items()} + model.load_state_dict(sd) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and 'scaler' in checkpoint: + scaler.load_state_dict(checkpoint['scaler']) + logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})") + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") + + # initialize datasets + data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) + assert len(data), 'At least one train or eval dataset must be specified.' + + # create scheduler if train + scheduler = None + if 'train' in data and optimizer is not None: + total_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs + if args.lr_scheduler == "cosine": + scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) + elif args.lr_scheduler == "const": + scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps) + elif args.lr_scheduler == "const-cooldown": + assert args.epochs_cooldown is not None,\ + "Please specify the number of cooldown epochs for this lr schedule." + cooldown_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs_cooldown + scheduler = const_lr_cooldown( + optimizer, args.lr, args.warmup, total_steps, + cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end) + else: + logging.error( + f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.') + exit(1) + + # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, 'Please install wandb.' + logging.debug('Starting wandb.') + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + wandb.init( + project=args.wandb_project_name, + name=args.name, + id=args.name, + notes=args.wandb_notes, + tags=[], + resume='auto' if args.resume == "latest" else None, + config=vars(args), + ) + if args.debug: + wandb.watch(model, log='all') + wandb.save(params_file) + logging.debug('Finished loading wandb.') + + if args.torchcompile: + logging.info('Compiling model...') + model = torch.compile(model) + + if 'train' not in data: + # If using int8, convert to inference mode. + if args.use_bnb_linear is not None: + from open_clip.utils import convert_int8_model_to_inference_mode + convert_int8_model_to_inference_mode(model) + # Evaluate. + evaluate(model, data, start_epoch, args, writer) + return + + loss = create_loss(args) + + for epoch in range(start_epoch, args.epochs): + if is_master(args): + logging.info(f'Start epoch {epoch}') + + train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer) + completed_epoch = epoch + 1 + + if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')): + evaluate(model, data, completed_epoch, args, writer) + + # Saving checkpoints. + if args.save_logs: + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + } + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.delete_previous_checkpoint: + previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt") + if os.path.exists(previous_checkpoint): + os.remove(previous_checkpoint) + + if args.save_most_recent: + # try not to corrupt the latest checkpoint if save fails + tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt") + latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME) + torch.save(checkpoint_dict, tmp_save_path) + os.replace(tmp_save_path, latest_save_path) + + if args.wandb and is_master(args): + wandb.finish() + + # run a final sync. + if remote_sync_process is not None: + logging.info('Final remote sync.') + remote_sync_process.terminate() + result = remote_sync( + os.path.join(args.logs, args.name), + os.path.join(args.remote_sync, args.name), + args.remote_sync_protocol + ) + if result: + logging.info('Final remote sync successful.') + else: + logging.info('Final remote sync failed.') + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." + ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb')) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/training/params.py b/training/params.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f65d2b55ab70a7168ff147699b3ce278839c41 --- /dev/null +++ b/training/params.py @@ -0,0 +1,530 @@ +import argparse +import ast + + +def get_default_params(model_name): + # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) + model_name = model_name.lower() + if "vit" in model_name: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} + else: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} + + +class ParseKwargs(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + kw = {} + for value in values: + key, value = value.split('=') + try: + kw[key] = ast.literal_eval(value) + except ValueError: + kw[key] = str(value) # fallback to string (avoid need to escape on command line) + setattr(namespace, self.dest, kw) + + +def parse_args(args): + parser = argparse.ArgumentParser() + + ################################### + # my new params + parser.add_argument("--cache-dir", type=str, default='', help="",) + parser.add_argument("--languagebind_weight", type=str, default='', help="",) + parser.add_argument("--num-frames", type=int, default=8, help="",) + parser.add_argument("--tube-size", type=int, default=1, help="",) + parser.add_argument("--clip-type", type=str, default="", choices=['vl', 'al', 'dl', 'tl', 'vl_new'], help="",) + parser.add_argument("--text-type", type=str, default="chatgpt", help="'raw', 'ofa', 'mplug', 'polish_mplug'",) + parser.add_argument("--add-time-attn", default=False, action="store_true", help="") + parser.add_argument("--unlock-time-attn", default=False, action="store_true", help="") + parser.add_argument("--coef-lr", type=float, default=1e-4, help="") + parser.add_argument("--init-temp", type=float, default=0, help="",) + parser.add_argument("--local_rank", type=int, default=-1, help="",) + parser.add_argument("--learn-temp", default=False, action="store_true", help="") + parser.add_argument("--video-decode-backend", type=str, default="opencv", choices=['pytorchvideo', 'decord', 'opencv', 'imgs'], help="") + parser.add_argument("--do_train", action='store_true', help="Whether to run training.") + parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") + + ############################ + # LoRA + parser.add_argument("--convert_to_lora", action='store_true', help="Whether to run eval on the dev set.") + parser.add_argument('--lora_r', type=int, default=16, help='') + parser.add_argument('--lora_alpha', type=int, default=16, help='') + parser.add_argument('--lora_dropout', type=float, default=0.0, help='') + + ############################ + # depth classification + parser.add_argument('--val_d_cls_data', nargs='+', help="Point the dataset to finetune.") + parser.add_argument("--depth_data_path", default="", type=str, help="") + parser.add_argument("--max-depth", type=int, default=10, help="") + + ############################ + # thermal classification + parser.add_argument('--val_t_cls_data', nargs='+', help="Point the dataset to finetune.") + parser.add_argument("--thermal_data_path", default="", type=str, help="") + + ############################ + # audio classification + parser.add_argument('--use_audios', nargs='+', help="Point the dataset.") + parser.add_argument('--data_val', type=str, default='', help='') + parser.add_argument('--label_csv', type=str, default='', help='') + parser.add_argument('--val_a_cls_data', nargs='+', help="Point the dataset to finetune.") + parser.add_argument('--val_al_ret_data', nargs='+', help="Point the dataset to finetune.") + parser.add_argument('--num_mel_bins', type=int, default=128, help='') + parser.add_argument('--target_length', type=int, default=1024, help='') + parser.add_argument('--audio_sample_rate', type=int, default=16000, help='') + parser.add_argument('--audio_mean', type=float, default=-4.2677393, help='') + parser.add_argument('--audio_std', type=float, default=4.5689974, help='') + + ############################## + # video-text retrieval + parser.add_argument('--val_vl_ret_data', nargs='+', help="Point the dataset to finetune.") + parser.add_argument('--train_csv', type=str, default='data/.train.csv', help='') + parser.add_argument('--val_csv', type=str, default='data/.val.csv', help='') + parser.add_argument('--data_path', type=str, default='data/caption.pickle', help='data pickle file path') + parser.add_argument('--features_path', type=str, default='data/videos_feature.pickle', help='feature path') + parser.add_argument('--eval_frame_order', type=int, default=0, choices=[0, 1, 2], help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.") + parser.add_argument('--feature_framerate', type=int, default=1, help='') + parser.add_argument('--slice_framepos', type=int, default=2, choices=[0, 1, 2], help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.") + parser.add_argument('--max_frames', type=int, default=8, help='') + parser.add_argument('--max_words', type=int, default=77, help='') + parser.add_argument('--batch_size_val', type=int, default=0, help='batch size eval') + parser.add_argument('--num_thread_reader', type=int, default=10, help='') + + ############################ + # video classification + parser.add_argument('--val_v_cls_data', nargs='+', help="Point the dataset to finetune.") + parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation') + parser.add_argument('--sparse_sample', default=False, action='store_true') + parser.add_argument('--data_set', default='Kinetics-400', choices=['Kinetics-400', 'Kinetics-600'], type=str, help='dataset') + parser.add_argument('--nb_classes', default=400, type=int, help='number of the classification types') + parser.add_argument('--video_data_path', default='/your/data/path/', type=str, help='dataset path') + parser.add_argument('--data_root', default='', type=str, help='dataset path root') + parser.add_argument('--input_size', default=224, type=int, help='images input size') + parser.add_argument('--short_side_size', type=int, default=224) + parser.add_argument('--test_num_segment', type=int, default=10) + parser.add_argument('--test_num_crop', type=int, default=3) + parser.add_argument('--sampling_rate', type=int, default=16) + parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob (default: 0.25)') + + ####################### + # origin open-clip params + parser.add_argument( + "--train-data", + type=str, + default=None, + help="Path to file(s) with training data. When using webdataset, multiple datasources can be combined using the `::` separator.", + ) + parser.add_argument( + "--train-data-upsampling-factors", + type=str, + default=None, + help=( + "When using multiple data sources with webdataset and sampling with replacement, this can be used to upsample specific data sources. " + "Similar to --train-data, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5) " + "By default, datapoints are sampled uniformly regardless of the dataset sizes." + ) + ) + parser.add_argument( + "--val-data", + type=str, + default=None, + help="Path to file(s) with validation data", + ) + parser.add_argument( + "--train-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Required for webdataset if not available in info file.", + ) + parser.add_argument( + "--val-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Useful for webdataset if not available in info file.", + ) + parser.add_argument( + "--dataset-type", + choices=["webdataset", "json", "csv", "synthetic", "auto"], + default="auto", + help="Which type of dataset to process." + ) + parser.add_argument( + "--dataset-resampled", + default=False, + action="store_true", + help="Whether to use sampling with replacement for webdataset shard selection." + ) + parser.add_argument( + "--csv-separator", + type=str, + default="\t", + help="For csv-like datasets, which separator to use." + ) + parser.add_argument( + "--csv-img-key", + type=str, + default="filepath", + help="For csv-like datasets, the name of the key for the image paths." + ) + parser.add_argument( + "--csv-caption-key", + type=str, + default="title", + help="For csv-like datasets, the name of the key for the captions." + ) + parser.add_argument( + "--imagenet-val", + type=str, + default=None, + help="Path to imagenet val set for conducting zero shot evaluation.", + ) + parser.add_argument( + "--imagenet-v2", + type=str, + default=None, + help="Path to imagenet v2 for conducting zero shot evaluation.", + ) + parser.add_argument( + "--logs", + type=str, + default="./logs/", + help="Where to store tensorboard logs. Use None to avoid storing logs.", + ) + parser.add_argument( + "--log-local", + action="store_true", + default=False, + help="log files on local master, otherwise global master only.", + ) + parser.add_argument( + "--name", + type=str, + default=None, + help="Optional identifier for the experiment when storing logs. Otherwise use current time.", + ) + parser.add_argument( + "--workers", type=int, default=1, help="Number of dataloader workers per GPU." + ) + parser.add_argument( + "--batch-size", type=int, default=64, help="Batch size per GPU." + ) + parser.add_argument( + "--epochs", type=int, default=32, help="Number of epochs to train for." + ) + parser.add_argument( + "--epochs-cooldown", type=int, default=None, + help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards." + ) + parser.add_argument("--lr", type=float, default=None, help="Learning rate.") + parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") + parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") + parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") + parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") + parser.add_argument( + "--warmup", type=int, default=10000, help="Number of steps to warmup for." + ) + parser.add_argument( + "--use-bn-sync", + default=False, + action="store_true", + help="Whether to use batch norm sync.") + parser.add_argument( + "--skip-scheduler", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--lr-scheduler", + type=str, + default='cosine', + help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine", + ) + parser.add_argument( + "--lr-cooldown-end", type=float, default=0.0, + help="End learning rate for cooldown schedule. Default: 0" + ) + parser.add_argument( + "--lr-cooldown-power", type=float, default=1.0, + help="Power for polynomial cooldown schedule. Default: 1.0 (linear decay)" + ) + parser.add_argument( + "--save-frequency", type=int, default=1, help="How often to save checkpoints." + ) + parser.add_argument( + "--save-most-recent", + action="store_true", + default=False, + help="Always save the most recent model trained to epoch_latest.pt.", + ) + parser.add_argument( + "--zeroshot-frequency", type=int, default=1, help="How often to run zero shot." + ) + parser.add_argument( + "--val-frequency", type=int, default=1, help="How often to run evaluation with val data." + ) + parser.add_argument( + "--resume", + default=None, + type=str, + help="path to latest checkpoint (default: none)", + ) + parser.add_argument( + "--precision", + choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "pure_bf16", "pure_fp16", "fp32"], + default="amp", + help="Floating point precision." + ) + parser.add_argument( + "--model", + type=str, + default="RN50", + help="Name of the vision backbone to use.", + ) + parser.add_argument( + "--pretrained", + default='', + type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--pretrained-image", + default=False, + action='store_true', + help="Load imagenet pretrained weights for image tower backbone if available.", + ) + parser.add_argument( + "--lock-image", + default=False, + action='store_true', + help="Lock full image tower by disabling gradients.", + ) + parser.add_argument( + "--lock-image-unlocked-groups", + type=int, + default=0, + help="Leave last n image tower layer groups unlocked.", + ) + parser.add_argument( + "--lock-image-freeze-bn-stats", + default=False, + action='store_true', + help="Freeze BatchNorm running stats in image tower for any locked layers.", + ) + parser.add_argument( + '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override default image mean value of dataset') + parser.add_argument( + '--image-std', type=float, nargs='+', default=None, metavar='STD', + help='Override default image std deviation of of dataset') + parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs) + parser.add_argument( + "--grad-checkpointing", + default=False, + action='store_true', + help="Enable gradient checkpointing.", + ) + parser.add_argument( + "--local-loss", + default=False, + action="store_true", + help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)" + ) + parser.add_argument( + "--gather-with-grad", + default=False, + action="store_true", + help="enable full distributed gradient for feature gather" + ) + parser.add_argument( + '--force-image-size', type=int, nargs='+', default=None, + help='Override default image size' + ) + parser.add_argument( + "--force-quick-gelu", + default=False, + action='store_true', + help="Force use of QuickGELU activation for non-OpenAI transformer models.", + ) + parser.add_argument( + "--force-patch-dropout", + default=None, + type=float, + help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper", + ) + parser.add_argument( + "--force-custom-text", + default=False, + action='store_true', + help="Force use of CustomTextCLIP model (separate text-tower).", + ) + parser.add_argument( + "--torchscript", + default=False, + action='store_true', + help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", + ) + parser.add_argument( + "--torchcompile", + default=False, + action='store_true', + help="torch.compile() the model, requires pytorch 2.0 or later.", + ) + parser.add_argument( + "--trace", + default=False, + action='store_true', + help="torch.jit.trace the model for inference / eval only", + ) + parser.add_argument( + "--accum-freq", type=int, default=1, help="Update the model every --acum-freq steps." + ) + # arguments for distributed training + parser.add_argument( + "--dist-url", + default="env://", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" + ) + parser.add_argument( + "--report-to", + default='', + type=str, + help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']" + ) + parser.add_argument( + "--wandb-notes", + default='', + type=str, + help="Notes if logging with wandb" + ) + parser.add_argument( + "--wandb-project-name", + type=str, + default='open-clip', + help="Name of the project if logging with wandb.", + ) + parser.add_argument( + "--debug", + default=False, + action="store_true", + help="If true, more information is logged." + ) + parser.add_argument( + "--copy-codebase", + default=False, + action="store_true", + help="If true, we copy the entire base on the log directory, and execute from there." + ) + parser.add_argument( + "--horovod", + default=False, + action="store_true", + help="Use horovod for distributed training." + ) + parser.add_argument( + "--ddp-static-graph", + default=False, + action='store_true', + help="Enable static graph optimization for DDP in PyTorch >= 1.11.", + ) + parser.add_argument( + "--no-set-device-rank", + default=False, + action="store_true", + help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)." + ) + parser.add_argument( + "--seed", type=int, default=0, help="Default random seed." + ) + parser.add_argument( + "--grad-clip-norm", type=float, default=None, help="Gradient clip." + ) + parser.add_argument( + "--lock-text", + default=False, + action='store_true', + help="Lock full text tower by disabling gradients.", + ) + parser.add_argument( + "--lock-text-unlocked-layers", + type=int, + default=0, + help="Leave last n image tower layer groups unlocked.", + ) + parser.add_argument( + "--lock-text-freeze-layer-norm", + default=False, + action='store_true', + help="Freeze BatchNorm running stats in image tower for any locked layers.", + ) + parser.add_argument( + "--log-every-n-steps", + type=int, + default=100, + help="Log every n steps to tensorboard/console/wandb.", + ) + parser.add_argument( + "--coca-caption-loss-weight", + type=float, + default=2.0, + help="Weight assigned to caption loss in CoCa." + ) + parser.add_argument( + "--coca-contrastive-loss-weight", + type=float, + default=1.0, + help="Weight assigned to contrastive loss when training CoCa." + ) + parser.add_argument( + "--remote-sync", + type=str, + default=None, + help="Optinoally sync with a remote path specified by this arg", + ) + parser.add_argument( + "--remote-sync-frequency", + type=int, + default=300, + help="How frequently to sync to a remote directly if --remote-sync is not None.", + ) + parser.add_argument( + "--remote-sync-protocol", + choices=["s3", "fsspec"], + default="s3", + help="How to do the remote sync backup if --remote-sync is not None.", + ) + parser.add_argument( + "--delete-previous-checkpoint", + default=False, + action="store_true", + help="If true, delete previous checkpoint after storing a new one." + ) + parser.add_argument( + "--distill-model", + default=None, + help='Which model arch to distill from, if any.' + ) + parser.add_argument( + "--distill-pretrained", + default=None, + help='Which pre-trained weights to distill from, if any.' + ) + parser.add_argument( + "--use-bnb-linear", + default=None, + help='Replace the network linear layers from the bitsandbytes library. ' + 'Allows int8 training/inference, etc.' + ) + args = parser.parse_args(args) + + # If some params are not passed, we use the default values based on model name. + default_params = get_default_params(args.model) + for name, val in default_params.items(): + if getattr(args, name) is None: + setattr(args, name, val) + + return args diff --git a/training/precision.py b/training/precision.py new file mode 100644 index 0000000000000000000000000000000000000000..a63b92256518d13afd57261df1568e26b1622201 --- /dev/null +++ b/training/precision.py @@ -0,0 +1,12 @@ +import torch +from contextlib import suppress + + +def get_autocast(precision): + if precision == 'amp': + return torch.cuda.amp.autocast + elif precision == 'amp_bfloat16' or precision == 'amp_bf16': + # amp_bfloat16 is more stable than amp float16 for clip training + return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return suppress diff --git a/training/profile.py b/training/profile.py new file mode 100644 index 0000000000000000000000000000000000000000..f10372cdef306e5e199db432b23062df1c098cf9 --- /dev/null +++ b/training/profile.py @@ -0,0 +1,158 @@ +import argparse + +import torch +import open_clip +import pandas as pd +from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis + + +parser = argparse.ArgumentParser(description='OpenCLIP Profiler') + +# benchmark specific args +parser.add_argument('--model', metavar='NAME', default='', + help='model(s) to profile') +parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', + help='Output csv file for results') + + +def profile_fvcore( + model, + image_input_size=(3, 224, 224), + text_input_size=(77,), + batch_size=1, + detailed=False, + force_cpu=False +): + if force_cpu: + model = model.to('cpu') + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) + example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) + fca = FlopCountAnalysis(model, (example_image_input, example_text_input)) + aca = ActivationCountAnalysis(model, (example_image_input, example_text_input)) + if detailed: + fcs = flop_count_str(fca) + print(fcs) + return fca.total(), aca.total() + + +def profile_fvcore_text( + model, + text_input_size=(77,), + batch_size=1, + detailed=False, + force_cpu=False +): + if force_cpu: + model = model.to('cpu') + device = next(model.parameters()).device + example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) + fca = FlopCountAnalysis(model, example_input) + aca = ActivationCountAnalysis(model, example_input) + if detailed: + fcs = flop_count_str(fca) + print(fcs) + return fca.total(), aca.total() + + +def profile_fvcore_image( + model, + image_input_size=(3, 224, 224), + batch_size=1, + detailed=False, + force_cpu=False +): + if force_cpu: + model = model.to('cpu') + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) + fca = FlopCountAnalysis(model, example_input) + aca = ActivationCountAnalysis(model, example_input) + if detailed: + fcs = flop_count_str(fca) + print(fcs) + return fca.total(), aca.total() + + +def count_params(model): + return sum([m.numel() for m in model.parameters()]) + + +def profile_model(model_name): + model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False) + model.eval() + if torch.cuda.is_available(): + model = model.cuda() + + if isinstance(model.visual.image_size, (tuple, list)): + image_input_size = (3,) + tuple(model.visual.image_size[-2:]) + else: + image_input_size = (3, model.visual.image_size, model.visual.image_size) + text_input_size = (77,) + + results = {} + results['model'] = model_name + results['image_size'] = image_input_size[1] + + model_cfg = open_clip.get_model_config(model_name) + if model_cfg: + vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg']) + text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg']) + results['image_width'] = int(vision_cfg.width) + results['text_width'] = int(text_cfg.width) + results['embed_dim'] = int(model_cfg['embed_dim']) + else: + results['image_width'] = 0 + results['text_width'] = 0 + results['embed_dim'] = 0 + + retries = 2 + while retries: + retries -= 1 + try: + macs, acts = profile_fvcore( + model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries) + + image_macs, image_acts = profile_fvcore_image( + model.visual, image_input_size=image_input_size, force_cpu=not retries) + + text_macs, text_acts = profile_fvcore_text( + model.text, text_input_size=text_input_size, force_cpu=not retries) + + results['gmacs'] = round(macs / 1e9, 2) + results['macts'] = round(acts / 1e6, 2) + results['mparams'] = round(count_params(model) / 1e6, 2) + results['image_gmacs'] = round(image_macs / 1e9, 2) + results['image_macts'] = round(image_acts / 1e6, 2) + results['image_mparams'] = round(count_params(model.visual) / 1e6, 2) + results['text_gmacs'] = round(text_macs / 1e9, 2) + results['text_macts'] = round(text_acts / 1e6, 2) + results['text_mparams'] = round(count_params(model.text) / 1e6, 2) + except RuntimeError as e: + pass + return results + + +def main(): + args = parser.parse_args() + + # FIXME accept a text file name to allow lists of models in txt/csv + if args.model == 'all': + parsed_model = open_clip.list_models() + else: + parsed_model = args.model.split(',') + + results = [] + for m in parsed_model: + row = profile_model(m) + results.append(row) + + df = pd.DataFrame(results, columns=results[0].keys()) + df = df.sort_values('gmacs') + print(df) + if args.results_file: + df.to_csv(args.results_file, index=False) + + +if __name__ == '__main__': + main() diff --git a/training/scheduler.py b/training/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..fba76fcf1720b11d136a5ab6d3a58ab2fbe42f74 --- /dev/null +++ b/training/scheduler.py @@ -0,0 +1,53 @@ +import numpy as np + + +def assign_learning_rate(optimizer, new_lr): + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr + + +def _warmup_lr(base_lr, warmup_length, step): + return base_lr * (step + 1) / warmup_length + + +def const_lr(optimizer, base_lr, warmup_length, steps): + def _lr_adjuster(step): + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + lr = base_lr + assign_learning_rate(optimizer, lr) + return lr + return _lr_adjuster + + +def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): + def _lr_adjuster(step): + start_cooldown_step = steps - cooldown_steps + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + if step < start_cooldown_step: + lr = base_lr + else: + e = step - start_cooldown_step + es = steps - start_cooldown_step + # linear decay if power == 1; polynomial decay otherwise; + decay = (1 - (e/es)) ** cooldown_power + lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr + assign_learning_rate(optimizer, lr) + return lr + return _lr_adjuster + + +def cosine_lr(optimizer, base_lr, warmup_length, steps): + def _lr_adjuster(step): + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + e = step - warmup_length + es = steps - warmup_length + lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr + assign_learning_rate(optimizer, lr) + return lr + return _lr_adjuster diff --git a/training/train.py b/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..30c127aba6ab1d8c5e457a22a0ce1bce1eb0f95c --- /dev/null +++ b/training/train.py @@ -0,0 +1,363 @@ +import json +import logging +import math +import os +import time + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.parallel.distributed import DistributedDataParallel + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import get_input_dtype, CLIP, CustomTextCLIP +from .distributed import is_master +from .zero_shot import zero_shot_eval +from .precision import get_autocast + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +def postprocess_clip_output(model_out): + return { + "image_features": model_out[0], + "text_features": model_out[1], + "logit_scale": model_out[2] + } + +def unwrap_model(model): + if hasattr(model, 'module'): + return model.module + else: + return model + + +def backward(total_loss, scaler): + if scaler is not None: + scaler.scale(total_loss).backward() + else: + total_loss.backward() + + +def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None): + device = torch.device(args.device) + autocast = get_autocast(args.precision) + input_dtype = get_input_dtype(args.precision) + + + model.train() + if args.distill: + dist_model.eval() + + data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch + dataloader = data['train'].dataloader + num_batches_per_epoch = dataloader.num_batches // args.accum_freq + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + if args.accum_freq > 1: + accum_images, accum_texts, accum_features = [], [], {} + + losses_m = {} + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + for i, batch in enumerate(dataloader): + i_accum = i // args.accum_freq + step = num_batches_per_epoch * epoch + i_accum + + if not args.skip_scheduler: + scheduler(step) + + images, texts = batch + images = images.to(device=device, dtype=input_dtype, non_blocking=True) + texts = texts.to(device=device, non_blocking=True) + # images = images.to(device=device, dtype=input_dtype, non_blocking=False) + # texts = texts.to(device=device, non_blocking=False) + + data_time_m.update(time.time() - end) + optimizer.zero_grad() + + if args.accum_freq == 1: + with autocast(): + model_out = model(images, texts) + logit_scale = model_out["logit_scale"] + if args.distill: + with torch.no_grad(): + dist_model_out = dist_model(images, texts) + model_out.update({f'dist_{k}': v for k, v in dist_model_out.items()}) + losses = loss(**model_out, output_dict=True) + + total_loss = sum(losses.values()) + losses["loss"] = total_loss + + backward(total_loss, scaler) + else: + # First, cache the features without any gradient tracking. + with torch.no_grad(): + with autocast(): + model_out = model(images, texts) + model_out.pop("logit_scale") + for key, val in model_out.items(): + if key in accum_features: + accum_features[key].append(val) + else: + accum_features[key] = [val] + + accum_images.append(images) + accum_texts.append(texts) + + # If (i + 1) % accum_freq is not zero, move on to the next batch. + if ((i + 1) % args.accum_freq) > 0: + # FIXME this makes data time logging unreliable when accumulating + continue + + # Now, ready to take gradients for the last accum_freq batches. + # Re-do the forward pass for those batches, and use the cached features from the other batches as negatives. + # Call backwards each time, but only step optimizer at the end. + optimizer.zero_grad() + for j in range(args.accum_freq): + images = accum_images[j] + texts = accum_texts[j] + with autocast(): + model_out = model(images, texts) + logit_scale = model_out.pop("logit_scale") + inputs = {} + for key, val in accum_features.items(): + accumulated = accum_features[key] + inputs[key] = torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:]) + losses = loss(**inputs, logit_scale=logit_scale, output_dict=True) + del inputs + total_loss = sum(losses.values()) + losses["loss"] = total_loss + backward(total_loss, scaler) + + if scaler is not None: + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + if args.grad_clip_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + if args.grad_clip_norm is not None: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + scaler.step(optimizer) + scaler.update() + else: + if args.grad_clip_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) + optimizer.step() + + # reset gradient accum, if enabled + if args.accum_freq > 1: + accum_images, accum_texts, accum_features = [], [], {} + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. + with torch.no_grad(): + unwrap_model(model).logit_scale.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i_accum + 1 + if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch): + batch_size = len(images) + num_samples = batch_count * batch_size * args.accum_freq * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + for key, val in losses.items(): + if key not in losses_m: + losses_m[key] = AverageMeter() + losses_m[key].update(val.item(), batch_size) + + logit_scale_scalar = logit_scale.item() + loss_log = " ".join( + [ + f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})" + for loss_name, loss_m in losses_m.items() + ] + ) + samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val + samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "samples_per_second": samples_per_second, + "samples_per_second_per_gpu": samples_per_second_per_gpu, + "scale": logit_scale_scalar, + "lr": optimizer.param_groups[0]["lr"] + } + log_data.update({name:val.val for name,val in losses_m.items()}) + + for name, val in log_data.items(): + name = "train/" + name + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, 'Please install wandb.' + wandb.log({name: val, 'step': step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None): + metrics = {} + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + metrics.update(zero_shot_metrics) + + autocast = get_autocast(args.precision) + input_dtype = get_input_dtype(args.precision) + + if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)): + dataloader = data['val'].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + # FIXME this does not scale past small eval datasets + # all_image_features @ all_text_features will blow up memory and compute very quickly + cumulative_loss = 0.0 + cumulative_gen_loss = 0.0 + all_image_features, all_text_features = [], [] + with torch.no_grad(): + for i, batch in enumerate(dataloader): + images, texts = batch + images = images.to(device=device, dtype=input_dtype, non_blocking=True) + texts = texts.to(device=device, non_blocking=True) + + with autocast(): + model_out = model(images, texts) + image_features = model_out["image_features"] + text_features = model_out["text_features"] + logit_scale = model_out["logit_scale"] + # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly + # however, system RAM is easily exceeded and compute time becomes problematic + all_image_features.append(image_features.cpu()) + all_text_features.append(text_features.cpu()) + logit_scale = logit_scale.mean() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + batch_size = images.shape[0] + labels = torch.arange(batch_size, device=device).long() + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + gen_loss = maybe_compute_generative_loss(model_out) + + cumulative_loss += total_loss * batch_size + num_samples += batch_size + if is_master(args) and (i % 100) == 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" + f"Clip Loss: {cumulative_loss / num_samples:.6f}\t") + + if gen_loss is not None: + cumulative_gen_loss += gen_loss * batch_size + logging.info( + f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t") + + val_metrics = get_clip_metrics( + image_features=torch.cat(all_image_features), + text_features=torch.cat(all_text_features), + logit_scale=logit_scale.cpu(), + ) + loss = cumulative_loss / num_samples + metrics.update( + {**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} + ) + if gen_loss is not None: + gen_loss = cumulative_gen_loss / num_samples + metrics.update({"val_generative_loss": gen_loss.item()}) + + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) + ) + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, 'Please install wandb.' + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, 'epoch': epoch}) + + return metrics + + +def get_clip_metrics(image_features, text_features, logit_scale): + metrics = {} + logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() + logits_per_text = logits_per_image.t().detach().cpu() + + logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + for name, logit in logits.items(): + ranking = torch.argsort(logit, descending=True) + preds = torch.where(ranking == ground_truth)[1] + preds = preds.detach().cpu().numpy() + metrics[f"{name}_mean_rank"] = preds.mean() + 1 + metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = np.mean(preds < k) + + return metrics + + +def maybe_compute_generative_loss(model_out): + if "logits" in model_out and "labels" in model_out: + token_logits = model_out["logits"] + token_labels = model_out["labels"] + return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels) diff --git a/training/zero_shot.py b/training/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..8265b424b247030abbb7d4ede289a0f890fdcdd4 --- /dev/null +++ b/training/zero_shot.py @@ -0,0 +1,84 @@ +import logging + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ + IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES +from .precision import get_autocast + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] + + +def run(model, classifier, dataloader, args): + autocast = get_autocast(args.precision) + input_dtype = get_input_dtype(args.precision) + + with torch.no_grad(): + top1, top5, n = 0., 0., 0. + for images, target in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(device=args.device, dtype=input_dtype) + target = target.to(args.device) + + with autocast(): + # predict + output = model(image=images) + image_features = output['image_features'] if isinstance(output, dict) else output[0] + logits = 100. * image_features @ classifier + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = (top1 / n) + top5 = (top5 / n) + return top1, top5 + + +def zero_shot_eval(model, data, epoch, args): + if 'imagenet-val' not in data and 'imagenet-v2' not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + if args.distributed and not args.horovod: + model = model.module + + logging.info('Starting zero-shot imagenet.') + + logging.info('Building zero-shot classifier') + autocast = get_autocast(args.precision) + with autocast(): + tokenizer = get_tokenizer(args.model) + classifier = build_zero_shot_classifier( + model, + tokenizer=tokenizer, + classnames=IMAGENET_CLASSNAMES, + templates=OPENAI_IMAGENET_TEMPLATES, + num_classes_per_batch=10, + device=args.device, + use_tqdm=True, + ) + + logging.info('Using classifier') + results = {} + if 'imagenet-val' in data: + top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) + results['imagenet-zeroshot-val-top1'] = top1 + results['imagenet-zeroshot-val-top5'] = top5 + if 'imagenet-v2' in data: + top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) + results['imagenetv2-zeroshot-val-top1'] = top1 + results['imagenetv2-zeroshot-val-top5'] = top5 + + logging.info('Finished zero-shot imagenet.') + + return results diff --git a/v_cls/__init__.py b/v_cls/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0a87061c2c9a228ccbe1597b6cdd5a580537d9 --- /dev/null +++ b/v_cls/__init__.py @@ -0,0 +1,110 @@ +import os + +import torch +from functools import partial +from .build import build_dataset, build_pretraining_dataset +from torch.utils.data._utils.collate import default_collate + +__all__ = ['build_dataset', 'build_pretraining_dataset'] + + +def multiple_samples_collate(batch, fold=False): + """ + Collate function for repeated augmentation. Each instance in the batch has + more than one sample. + Args: + batch (tuple or list): data batch to collate. + Returns: + (tuple): collated data batch. + """ + inputs, labels, video_idx, extra_data = zip(*batch) + inputs = [item for sublist in inputs for item in sublist] + labels = [item for sublist in labels for item in sublist] + video_idx = [item for sublist in video_idx for item in sublist] + inputs, labels, video_idx, extra_data = ( + default_collate(inputs), + default_collate(labels), + default_collate(video_idx), + default_collate(extra_data), + ) + if fold: + return [inputs], labels, video_idx, extra_data + else: + return inputs, labels, video_idx, extra_data + +def get_video_cls_dataloader(args): + dataset_train, args.nb_classes = build_dataset(is_train=True, test_mode=False, args=args) + # if args.disable_eval_during_finetuning: + # dataset_val = None + # else: + dataset_val, _ = build_dataset(is_train=False, test_mode=False, args=args) + dataset_test, _ = build_dataset(is_train=False, test_mode=True, args=args) + + num_tasks = args.world_size + global_rank = args.rank + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) + # print("Sampler_train = %s" % str(sampler_train)) + if args.dist_eval: + if len(dataset_val) % num_tasks != 0: + print( + 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' + 'This will slightly alter validation results as extra duplicate entries are added to achieve ' + 'equal num of samples per-process.') + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, + num_replicas=num_tasks, + rank=global_rank, + shuffle=False) + sampler_test = torch.utils.data.DistributedSampler( + dataset_test, + num_replicas=num_tasks, + rank=global_rank, + shuffle=False) + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + if args.num_sample > 1: + collate_func = partial(multiple_samples_collate, fold=False) + else: + collate_func = None + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, + sampler=sampler_train, + batch_size=args.batch_size, + # batch_size=16, ###################################### + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + collate_fn=collate_func, + persistent_workers=True) + + if dataset_val is not None: + data_loader_val = torch.utils.data.DataLoader( + dataset_val, + sampler=sampler_val, + batch_size=int(1.5 * args.batch_size), + # batch_size=16, #################################### + num_workers=args.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=True) + else: + data_loader_val = None + + if dataset_test is not None: + data_loader_test = torch.utils.data.DataLoader( + dataset_test, + sampler=sampler_test, + batch_size=args.batch_size, + # batch_size=16, ##################################### + num_workers=args.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=True) + else: + data_loader_test = None + + # return data_loader_train, data_loader_val, data_loader_test + return data_loader_test \ No newline at end of file diff --git a/v_cls/build.py b/v_cls/build.py new file mode 100644 index 0000000000000000000000000000000000000000..a262f7c7c239f6295eab3527aad85d88fd0b58a1 --- /dev/null +++ b/v_cls/build.py @@ -0,0 +1,349 @@ +# -------------------------------------------------------- +# Based on BEiT, timm, DINO and DeiT code bases +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/facebookresearch/deit +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import os + +from torchvision import transforms + +from .datasets import RawFrameClsDataset, VideoClsDataset +from .masking_generator import ( + RunningCellMaskingGenerator, + TubeMaskingGenerator, +) +from .pretrain_datasets import HybridVideoMAE, VideoMAE # noqa: F401 +from .transforms import ( + GroupMultiScaleCrop, + GroupNormalize, + Stack, + ToTorchFormatTensor, +) + + +class DataAugmentationForVideoMAEv2(object): + + def __init__(self, args): + self.input_mean = [0.485, 0.456, 0.406] + self.input_std = [0.229, 0.224, 0.225] + div = True + roll = False + normalize = GroupNormalize(self.input_mean, self.input_std) + self.train_augmentation = GroupMultiScaleCrop(args.input_size, + [1, .875, .75, .66]) + self.transform = transforms.Compose([ + self.train_augmentation, + Stack(roll=roll), + ToTorchFormatTensor(div=div), + normalize, + ]) + if args.mask_type == 'tube': + self.encoder_mask_map_generator = TubeMaskingGenerator( + args.window_size, args.mask_ratio) + else: + raise NotImplementedError( + 'Unsupported encoder masking strategy type.') + if args.decoder_mask_ratio > 0.: + if args.decoder_mask_type == 'run_cell': + self.decoder_mask_map_generator = RunningCellMaskingGenerator( + args.window_size, args.decoder_mask_ratio) + else: + raise NotImplementedError( + 'Unsupported decoder masking strategy type.') + + def __call__(self, images): + process_data, _ = self.transform(images) + encoder_mask_map = self.encoder_mask_map_generator() + if hasattr(self, 'decoder_mask_map_generator'): + decoder_mask_map = self.decoder_mask_map_generator() + else: + decoder_mask_map = 1 - encoder_mask_map + return process_data, encoder_mask_map, decoder_mask_map + + def __repr__(self): + repr = "(DataAugmentationForVideoMAEv2,\n" + repr += " transform = %s,\n" % str(self.transform) + repr += " Encoder Masking Generator = %s,\n" % str( + self.encoder_mask_map_generator) + if hasattr(self, 'decoder_mask_map_generator'): + repr += " Decoder Masking Generator = %s,\n" % str( + self.decoder_mask_map_generator) + else: + repr += " Do not use decoder masking,\n" + repr += ")" + return repr + + +def build_pretraining_dataset(args): + transform = DataAugmentationForVideoMAEv2(args) + dataset = VideoMAE( + root=args.data_root, + setting=args.video_data_path, + train=True, + test_mode=False, + name_pattern=args.fname_tmpl, + video_ext='mp4', + is_color=True, + modality='rgb', + num_segments=1, + num_crop=1, + new_length=args.num_frames, + new_step=args.sampling_rate, + transform=transform, + temporal_jitter=False, + lazy_init=False, + num_sample=args.num_sample) + # print("Data Aug = %s" % str(transform)) + return dataset + + +def build_dataset(is_train, test_mode, args): + if is_train: + mode = 'train' + anno_path = os.path.join(args.video_data_path, 'train.csv') + elif test_mode: + mode = 'test' + anno_path = os.path.join(args.video_data_path, 'val.csv') + else: + mode = 'validation' + anno_path = os.path.join(args.video_data_path, 'val.csv') + + if args.data_set == 'Kinetics-400': + if not args.sparse_sample: + dataset = VideoClsDataset( + anno_path=anno_path, + data_root=args.data_root, + mode=mode, + clip_len=args.num_frames, + frame_sample_rate=args.sampling_rate, + num_segment=1, + test_num_segment=args.test_num_segment, + test_num_crop=args.test_num_crop, + num_crop=1 if not test_mode else 3, + keep_aspect_ratio=True, + crop_size=args.input_size, + short_side_size=args.short_side_size, + new_height=256, + new_width=320, + sparse_sample=False, + args=args) + else: + dataset = VideoClsDataset( + anno_path=anno_path, + data_root=args.data_root, + mode=mode, + clip_len=1, + frame_sample_rate=1, + num_segment=args.num_frames, + test_num_segment=args.test_num_segment, + test_num_crop=args.test_num_crop, + num_crop=1 if not test_mode else 3, + keep_aspect_ratio=True, + crop_size=args.input_size, + short_side_size=args.short_side_size, + new_height=256, + new_width=320, + sparse_sample=True, + args=args) + nb_classes = 400 + + elif args.data_set == 'Kinetics-600': + dataset = VideoClsDataset( + anno_path=anno_path, + data_root=args.data_root, + mode=mode, + clip_len=args.num_frames, + frame_sample_rate=args.sampling_rate, + num_segment=1, + test_num_segment=args.test_num_segment, + test_num_crop=args.test_num_crop, + num_crop=1 if not test_mode else 3, + keep_aspect_ratio=True, + crop_size=args.input_size, + short_side_size=args.short_side_size, + new_height=256, + new_width=320, + args=args) + nb_classes = 600 + + elif args.data_set == 'Kinetics-700': + dataset = VideoClsDataset( + anno_path=anno_path, + data_root=args.data_root, + mode=mode, + clip_len=args.num_frames, + frame_sample_rate=args.sampling_rate, + num_segment=1, + test_num_segment=args.test_num_segment, + test_num_crop=args.test_num_crop, + num_crop=1 if not test_mode else 3, + keep_aspect_ratio=True, + crop_size=args.input_size, + short_side_size=args.short_side_size, + new_height=256, + new_width=320, + args=args) + nb_classes = 700 + + elif args.data_set == 'Kinetics-710': + dataset = VideoClsDataset( + anno_path=anno_path, + data_root=args.data_root, + mode=mode, + clip_len=args.num_frames, + frame_sample_rate=args.sampling_rate, + num_segment=1, + test_num_segment=args.test_num_segment, + test_num_crop=args.test_num_crop, + num_crop=1 if not test_mode else 3, + keep_aspect_ratio=True, + crop_size=args.input_size, + short_side_size=args.short_side_size, + new_height=256, + new_width=320, + args=args) + nb_classes = 710 + + elif args.data_set == 'SSV2': + dataset = VideoClsDataset( + anno_path=anno_path, + data_root=args.data_root, + mode=mode, + clip_len=args.num_frames, + frame_sample_rate=args.sampling_rate, + num_segment=1, + test_num_segment=args.test_num_segment, + test_num_crop=args.test_num_crop, + num_crop=1 if not test_mode else 3, + keep_aspect_ratio=True, + crop_size=args.input_size, + short_side_size=args.short_side_size, + new_height=256, + new_width=320, + args=args) + nb_classes = 174 + + # elif args.data_set == 'SSV2': + # dataset = RawFrameClsDataset( + # anno_path=anno_path, + # data_root=args.data_root, + # mode=mode, + # clip_len=1, + # num_segment=args.num_frames, + # test_num_segment=args.test_num_segment, + # test_num_crop=args.test_num_crop, + # num_crop=1 if not test_mode else 3, + # keep_aspect_ratio=True, + # crop_size=args.input_size, + # short_side_size=args.short_side_size, + # new_height=256, + # new_width=320, + # filename_tmpl=args.fname_tmpl, + # start_idx=args.start_idx, + # args=args) + # + # nb_classes = 174 + + elif args.data_set == 'UCF101': + dataset = VideoClsDataset( + anno_path=anno_path, + data_root=args.data_root, + mode=mode, + clip_len=args.num_frames, + frame_sample_rate=args.sampling_rate, + num_segment=1, + test_num_segment=args.test_num_segment, + test_num_crop=args.test_num_crop, + num_crop=1 if not test_mode else 3, + keep_aspect_ratio=True, + crop_size=args.input_size, + short_side_size=args.short_side_size, + new_height=256, + new_width=320, + args=args) + nb_classes = 101 + + elif args.data_set == 'HMDB51': + dataset = VideoClsDataset( + anno_path=anno_path, + data_root=args.data_root, + mode=mode, + clip_len=args.num_frames, + frame_sample_rate=args.sampling_rate, + num_segment=1, + test_num_segment=args.test_num_segment, + test_num_crop=args.test_num_crop, + num_crop=1 if not test_mode else 3, + keep_aspect_ratio=True, + crop_size=args.input_size, + short_side_size=args.short_side_size, + new_height=256, + new_width=320, + args=args) + nb_classes = 51 + + elif args.data_set == 'Diving48': + dataset = VideoClsDataset( + anno_path=anno_path, + data_root=args.data_root, + mode=mode, + clip_len=args.num_frames, + frame_sample_rate=args.sampling_rate, + num_segment=1, + test_num_segment=args.test_num_segment, + test_num_crop=args.test_num_crop, + num_crop=1 if not test_mode else 3, + keep_aspect_ratio=True, + crop_size=args.input_size, + short_side_size=args.short_side_size, + new_height=256, + new_width=320, + args=args) + nb_classes = 48 + elif args.data_set == 'MIT': + if not args.sparse_sample: + dataset = VideoClsDataset( + anno_path=anno_path, + data_root=args.data_root, + mode=mode, + clip_len=args.num_frames, + frame_sample_rate=args.sampling_rate, + num_segment=1, + test_num_segment=args.test_num_segment, + test_num_crop=args.test_num_crop, + num_crop=1 if not test_mode else 3, + keep_aspect_ratio=True, + crop_size=args.input_size, + short_side_size=args.short_side_size, + new_height=256, + new_width=320, + sparse_sample=False, + args=args) + else: + dataset = VideoClsDataset( + anno_path=anno_path, + data_root=args.data_root, + mode=mode, + clip_len=1, + frame_sample_rate=1, + num_segment=args.num_frames, + test_num_segment=args.test_num_segment, + test_num_crop=args.test_num_crop, + num_crop=1 if not test_mode else 3, + keep_aspect_ratio=True, + crop_size=args.input_size, + short_side_size=args.short_side_size, + new_height=256, + new_width=320, + sparse_sample=True, + args=args) + nb_classes = 339 + else: + raise NotImplementedError('Unsupported Dataset') + + assert nb_classes == args.nb_classes + # print("Number of the class = %d" % args.nb_classes) + + return dataset, nb_classes diff --git a/v_cls/datasets.py b/v_cls/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..847a51cdd2f1a6aabd7b2d20e2ed6e312bb1b0ec --- /dev/null +++ b/v_cls/datasets.py @@ -0,0 +1,715 @@ +# pylint: disable=line-too-long,too-many-lines,missing-docstring +import os +import warnings + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset +from torchvision import transforms + +from . import video_transforms, volume_transforms +from .loader import get_image_loader, get_video_loader +from .random_erasing import RandomErasing + + +class VideoClsDataset(Dataset): + """Load your own video classification dataset.""" + + def __init__(self, + anno_path, + data_root='', + mode='train', + clip_len=8, + frame_sample_rate=2, + crop_size=224, + short_side_size=256, + new_height=256, + new_width=340, + keep_aspect_ratio=True, + num_segment=1, + num_crop=1, + test_num_segment=10, + test_num_crop=3, + sparse_sample=False, + args=None): + self.anno_path = anno_path + self.data_root = data_root + self.mode = mode + self.clip_len = clip_len + self.frame_sample_rate = frame_sample_rate + self.crop_size = crop_size + self.short_side_size = short_side_size + self.new_height = new_height + self.new_width = new_width + self.keep_aspect_ratio = keep_aspect_ratio + self.num_segment = num_segment + self.test_num_segment = test_num_segment + self.num_crop = num_crop + self.test_num_crop = test_num_crop + self.sparse_sample = sparse_sample + self.args = args + self.aug = False + self.rand_erase = False + + if self.mode in ['train']: + self.aug = True + if self.args.reprob > 0: + self.rand_erase = True + + self.video_loader = get_video_loader() + + cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ') + self.dataset_samples = list( + cleaned[0].apply(lambda row: os.path.join(self.data_root, row))) + self.label_array = list(cleaned.values[:, 1]) + + if (mode == 'train'): + pass + + elif (mode == 'validation'): + self.data_transform = video_transforms.Compose([ + video_transforms.Resize( + self.short_side_size, interpolation='bilinear'), + video_transforms.CenterCrop( + size=(self.crop_size, self.crop_size)), + volume_transforms.ClipToTensor(), + video_transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + elif mode == 'test': + self.data_resize = video_transforms.Compose([ + video_transforms.Resize( + size=(short_side_size), interpolation='bilinear') + ]) + self.data_transform = video_transforms.Compose([ + volume_transforms.ClipToTensor(), + video_transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + self.test_seg = [] + self.test_dataset = [] + self.test_label_array = [] + for ck in range(self.test_num_segment): + for cp in range(self.test_num_crop): + for idx in range(len(self.label_array)): + sample_label = self.label_array[idx] + self.test_label_array.append(sample_label) + self.test_dataset.append(self.dataset_samples[idx]) + self.test_seg.append((ck, cp)) + + def __getitem__(self, index): + if self.mode == 'train': + args = self.args + scale_t = 1 + + sample = self.dataset_samples[index] + # T H W C + buffer = self.load_video(sample, sample_rate_scale=scale_t) + if len(buffer) == 0: + while len(buffer) == 0: + warnings.warn( + "video {} not correctly loaded during training".format( + sample)) + index = np.random.randint(self.__len__()) + sample = self.dataset_samples[index] + buffer = self.load_video(sample, sample_rate_scale=scale_t) + + if args.num_sample > 1: + frame_list = [] + label_list = [] + index_list = [] + for _ in range(args.num_sample): + new_frames = self._aug_frame(buffer, args) + label = self.label_array[index] + frame_list.append(new_frames) + label_list.append(label) + index_list.append(index) + return frame_list, label_list, index_list, {} + else: + buffer = self._aug_frame(buffer, args) + + return buffer, self.label_array[index], index, {} + + elif self.mode == 'validation': + sample = self.dataset_samples[index] + buffer = self.load_video(sample) + if len(buffer) == 0: + while len(buffer) == 0: + warnings.warn( + "video {} not correctly loaded during validation". + format(sample)) + index = np.random.randint(self.__len__()) + sample = self.dataset_samples[index] + buffer = self.load_video(sample) + buffer = self.data_transform(buffer) + return buffer, self.label_array[index], sample.split( + "/")[-1].split(".")[0] + + elif self.mode == 'test': + sample = self.test_dataset[index] + chunk_nb, split_nb = self.test_seg[index] + buffer = self.load_video(sample) + + while len(buffer) == 0: + warnings.warn( + "video {}, temporal {}, spatial {} not found during testing" + .format(str(self.test_dataset[index]), chunk_nb, split_nb)) + index = np.random.randint(self.__len__()) + sample = self.test_dataset[index] + chunk_nb, split_nb = self.test_seg[index] + buffer = self.load_video(sample) + + buffer = self.data_resize(buffer) + if isinstance(buffer, list): + buffer = np.stack(buffer, 0) + + if self.sparse_sample: + spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - + self.short_side_size) / ( + self.test_num_crop - 1) + temporal_start = chunk_nb + spatial_start = int(split_nb * spatial_step) + if buffer.shape[1] >= buffer.shape[2]: + buffer = buffer[temporal_start::self.test_num_segment, + spatial_start:spatial_start + + self.short_side_size, :, :] + else: + buffer = buffer[temporal_start::self.test_num_segment, :, + spatial_start:spatial_start + + self.short_side_size, :] + else: + spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - + self.short_side_size) / ( + self.test_num_crop - 1) + temporal_step = max( + 1.0 * (buffer.shape[0] - self.clip_len) / + (self.test_num_segment - 1), 0) + temporal_start = int(chunk_nb * temporal_step) + spatial_start = int(split_nb * spatial_step) + if buffer.shape[1] >= buffer.shape[2]: + buffer = buffer[temporal_start:temporal_start + + self.clip_len, + spatial_start:spatial_start + + self.short_side_size, :, :] + else: + buffer = buffer[temporal_start:temporal_start + + self.clip_len, :, + spatial_start:spatial_start + + self.short_side_size, :] + + buffer = self.data_transform(buffer) + return buffer, self.test_label_array[index], sample.split( + "/")[-1].split(".")[0], chunk_nb, split_nb + else: + raise NameError('mode {} unkown'.format(self.mode)) + + def _aug_frame(self, buffer, args): + aug_transform = video_transforms.create_random_augment( + input_size=(self.crop_size, self.crop_size), + auto_augment=args.aa, + interpolation=args.train_interpolation, + ) + + buffer = [transforms.ToPILImage()(frame) for frame in buffer] + + buffer = aug_transform(buffer) + + buffer = [transforms.ToTensor()(img) for img in buffer] + buffer = torch.stack(buffer) # T C H W + buffer = buffer.permute(0, 2, 3, 1) # T H W C + + # T H W C + buffer = tensor_normalize(buffer, [0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + # T H W C -> C T H W. + buffer = buffer.permute(3, 0, 1, 2) + # Perform data augmentation. + scl, asp = ( + [0.08, 1.0], + [0.75, 1.3333], + ) + + buffer = spatial_sampling( + buffer, + spatial_idx=-1, + min_scale=256, + max_scale=320, + # crop_size=224, + crop_size=args.input_size, + random_horizontal_flip=False if args.data_set == 'SSV2' else True, + inverse_uniform_sampling=False, + aspect_ratio=asp, + scale=scl, + motion_shift=False) + + if self.rand_erase: + erase_transform = RandomErasing( + args.reprob, + mode=args.remode, + max_count=args.recount, + num_splits=args.recount, + device="cpu", + ) + buffer = buffer.permute(1, 0, 2, 3) # C T H W -> T C H W + buffer = erase_transform(buffer) + buffer = buffer.permute(1, 0, 2, 3) # T C H W -> C T H W + + return buffer + + def load_video(self, sample, sample_rate_scale=1): + fname = sample + + try: + vr = self.video_loader(fname) + except Exception as e: + print(f"Failed to load video from {fname} with error {e}!") + return [] + + length = len(vr) + + if self.mode == 'test': + if self.sparse_sample: + tick = length / float(self.num_segment) + all_index = [] + for t_seg in range(self.test_num_segment): + tmp_index = [ + int(t_seg * tick / self.test_num_segment + tick * x) + for x in range(self.num_segment) + ] + all_index.extend(tmp_index) + all_index = list(np.sort(np.array(all_index))) + else: + all_index = [ + x for x in range(0, length, self.frame_sample_rate) + ] + while len(all_index) < self.clip_len: + all_index.append(all_index[-1]) + + vr.seek(0) + buffer = vr.get_batch(all_index).asnumpy() + return buffer + + # handle temporal segments + converted_len = int(self.clip_len * self.frame_sample_rate) + seg_len = length // self.num_segment + + all_index = [] + for i in range(self.num_segment): + if seg_len <= converted_len: + index = np.linspace( + 0, seg_len, num=seg_len // self.frame_sample_rate) + index = np.concatenate( + (index, + np.ones(self.clip_len - seg_len // self.frame_sample_rate) + * seg_len)) + index = np.clip(index, 0, seg_len - 1).astype(np.int64) + else: + if self.mode == 'validation': + end_idx = (converted_len + seg_len) // 2 + else: + end_idx = np.random.randint(converted_len, seg_len) + str_idx = end_idx - converted_len + index = np.linspace(str_idx, end_idx, num=self.clip_len) + index = np.clip(index, str_idx, end_idx - 1).astype(np.int64) + index = index + i * seg_len + all_index.extend(list(index)) + + all_index = all_index[::int(sample_rate_scale)] + vr.seek(0) + buffer = vr.get_batch(all_index).asnumpy() + return buffer + + def __len__(self): + # return 200 + if self.mode != 'test': + return len(self.dataset_samples) + else: + return len(self.test_dataset) + + +class RawFrameClsDataset(Dataset): + """Load your own raw frame classification dataset.""" + + def __init__(self, + anno_path, + data_root, + mode='train', + clip_len=8, + crop_size=224, + short_side_size=256, + new_height=256, + new_width=340, + keep_aspect_ratio=True, + num_segment=1, + num_crop=1, + test_num_segment=10, + test_num_crop=3, + filename_tmpl='img_{:05}.jpg', + start_idx=1, + args=None): + self.anno_path = anno_path + self.data_root = data_root + self.mode = mode + self.clip_len = clip_len + self.crop_size = crop_size + self.short_side_size = short_side_size + self.new_height = new_height + self.new_width = new_width + self.keep_aspect_ratio = keep_aspect_ratio + self.num_segment = num_segment + self.test_num_segment = test_num_segment + self.num_crop = num_crop + self.test_num_crop = test_num_crop + self.filename_tmpl = filename_tmpl + self.start_idx = start_idx + self.args = args + self.aug = False + self.rand_erase = False + + if self.mode in ['train']: + self.aug = True + if self.args.reprob > 0: + self.rand_erase = True + + self.image_loader = get_image_loader() + + cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ') + self.dataset_samples = list( + cleaned[0].apply(lambda row: os.path.join(self.data_root, row))) + self.total_frames = list(cleaned.values[:, 1]) + self.label_array = list(cleaned.values[:, -1]) + + if (mode == 'train'): + pass + + elif (mode == 'validation'): + self.data_transform = video_transforms.Compose([ + video_transforms.Resize( + self.short_side_size, interpolation='bilinear'), + video_transforms.CenterCrop( + size=(self.crop_size, self.crop_size)), + volume_transforms.ClipToTensor(), + video_transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + elif mode == 'test': + self.data_resize = video_transforms.Compose([ + video_transforms.Resize( + size=(short_side_size), interpolation='bilinear') + ]) + self.data_transform = video_transforms.Compose([ + volume_transforms.ClipToTensor(), + video_transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + self.test_seg = [] + self.test_dataset = [] + self.test_total_frames = [] + self.test_label_array = [] + for ck in range(self.test_num_segment): + for cp in range(self.test_num_crop): + for idx in range(len(self.label_array)): + self.test_seg.append((ck, cp)) + self.test_dataset.append(self.dataset_samples[idx]) + self.test_total_frames.append(self.total_frames[idx]) + self.test_label_array.append(self.label_array[idx]) + + def __getitem__(self, index): + if self.mode == 'train': + args = self.args + scale_t = 1 + + sample = self.dataset_samples[index] + total_frame = self.total_frames[index] + buffer = self.load_frame( + sample, total_frame, sample_rate_scale=scale_t) # T H W C + if len(buffer) == 0: + while len(buffer) == 0: + warnings.warn( + "video {} not correctly loaded during training".format( + sample)) + index = np.random.randint(self.__len__()) + sample = self.dataset_samples[index] + total_frame = self.total_frames[index] + buffer = self.load_frame( + sample, total_frame, sample_rate_scale=scale_t) + + if args.num_sample > 1: + frame_list = [] + label_list = [] + index_list = [] + for _ in range(args.num_sample): + new_frames = self._aug_frame(buffer, args) + label = self.label_array[index] + frame_list.append(new_frames) + label_list.append(label) + index_list.append(index) + return frame_list, label_list, index_list, {} + else: + buffer = self._aug_frame(buffer, args) + + return buffer, self.label_array[index], index, {} + + elif self.mode == 'validation': + sample = self.dataset_samples[index] + total_frame = self.total_frames[index] + buffer = self.load_frame(sample, total_frame) + if len(buffer) == 0: + while len(buffer) == 0: + warnings.warn( + "video {} not correctly loaded during validation". + format(sample)) + index = np.random.randint(self.__len__()) + sample = self.dataset_samples[index] + buffer = self.load_frame(sample, total_frame) + buffer = self.data_transform(buffer) + return buffer, self.label_array[index], sample.split( + "/")[-1].split(".")[0] + + elif self.mode == 'test': + sample = self.test_dataset[index] + total_frame = self.test_total_frames[index] + chunk_nb, split_nb = self.test_seg[index] + buffer = self.load_frame(sample, total_frame) + + while len(buffer) == 0: + warnings.warn( + "video {}, temporal {}, spatial {} not found during testing" + .format(str(self.test_dataset[index]), chunk_nb, split_nb)) + index = np.random.randint(self.__len__()) + sample = self.test_dataset[index] + total_frame = self.test_total_frames[index] + chunk_nb, split_nb = self.test_seg[index] + buffer = self.load_frame(sample, total_frame) + + buffer = self.data_resize(buffer) + if isinstance(buffer, list): + buffer = np.stack(buffer, 0) + + spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - + self.short_side_size) / ( + self.test_num_crop - 1) + temporal_start = chunk_nb + spatial_start = int(split_nb * spatial_step) + if buffer.shape[1] >= buffer.shape[2]: + buffer = buffer[temporal_start::self.test_num_segment, + spatial_start:spatial_start + + self.short_side_size, :, :] + else: + buffer = buffer[temporal_start::self.test_num_segment, :, + spatial_start:spatial_start + + self.short_side_size, :] + + buffer = self.data_transform(buffer) + return buffer, self.test_label_array[index], sample.split( + "/")[-1].split(".")[0], chunk_nb, split_nb + else: + raise NameError('mode {} unkown'.format(self.mode)) + + def _aug_frame(self, buffer, args): + aug_transform = video_transforms.create_random_augment( + input_size=(self.crop_size, self.crop_size), + auto_augment=args.aa, + interpolation=args.train_interpolation, + ) + + buffer = [transforms.ToPILImage()(frame) for frame in buffer] + + buffer = aug_transform(buffer) + + buffer = [transforms.ToTensor()(img) for img in buffer] + buffer = torch.stack(buffer) # T C H W + buffer = buffer.permute(0, 2, 3, 1) # T H W C + + # T H W C + buffer = tensor_normalize(buffer, [0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + # T H W C -> C T H W. + buffer = buffer.permute(3, 0, 1, 2) + # Perform data augmentation. + scl, asp = ( + [0.08, 1.0], + [0.75, 1.3333], + ) + + buffer = spatial_sampling( + buffer, + spatial_idx=-1, + min_scale=256, + max_scale=320, + crop_size=self.crop_size, + random_horizontal_flip=False if args.data_set == 'SSV2' else True, + inverse_uniform_sampling=False, + aspect_ratio=asp, + scale=scl, + motion_shift=False) + + if self.rand_erase: + erase_transform = RandomErasing( + args.reprob, + mode=args.remode, + max_count=args.recount, + num_splits=args.recount, + device="cpu", + ) + buffer = buffer.permute(1, 0, 2, 3) + buffer = erase_transform(buffer) + buffer = buffer.permute(1, 0, 2, 3) + + return buffer + + def load_frame(self, sample, num_frames, sample_rate_scale=1): + """Load video content using Decord""" + fname = sample + + if self.mode == 'test': + tick = num_frames / float(self.num_segment) + all_index = [] + for t_seg in range(self.test_num_segment): + tmp_index = [ + int(t_seg * tick / self.test_num_segment + tick * x) + for x in range(self.num_segment) + ] + all_index.extend(tmp_index) + all_index = list(np.sort(np.array(all_index) + self.start_idx)) + imgs = [] + for idx in all_index: + frame_fname = os.path.join(fname, + self.filename_tmpl.format(idx)) + img = self.image_loader(frame_fname) + imgs.append(img) + buffer = np.array(imgs) + return buffer + + # handle temporal segments + average_duration = num_frames // self.num_segment + all_index = [] + if average_duration > 0: + if self.mode == 'validation': + all_index = list( + np.multiply( + list(range(self.num_segment)), average_duration) + + np.ones(self.num_segment, dtype=int) * + (average_duration // 2)) + else: + all_index = list( + np.multiply( + list(range(self.num_segment)), average_duration) + + np.random.randint(average_duration, size=self.num_segment)) + elif num_frames > self.num_segment: + if self.mode == 'validation': + all_index = list(range(self.num_segment)) + else: + all_index = list( + np.sort( + np.random.randint(num_frames, size=self.num_segment))) + else: + all_index = [0] * (self.num_segment - num_frames) + list( + range(num_frames)) + all_index = list(np.array(all_index) + self.start_idx) + imgs = [] + for idx in all_index: + frame_fname = os.path.join(fname, self.filename_tmpl.format(idx)) + img = self.image_loader(frame_fname) + imgs.append(img) + buffer = np.array(imgs) + return buffer + + def __len__(self): + if self.mode != 'test': + return len(self.dataset_samples) + else: + return len(self.test_dataset) + + +def spatial_sampling( + frames, + spatial_idx=-1, + min_scale=256, + max_scale=320, + crop_size=224, + random_horizontal_flip=True, + inverse_uniform_sampling=False, + aspect_ratio=None, + scale=None, + motion_shift=False, +): + """ + Perform spatial sampling on the given video frames. If spatial_idx is + -1, perform random scale, random crop, and random flip on the given + frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling + with the given spatial_idx. + Args: + frames (tensor): frames of images sampled from the video. The + dimension is `num frames` x `height` x `width` x `channel`. + spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, + or 2, perform left, center, right crop if width is larger than + height, and perform top, center, buttom crop if height is larger + than width. + min_scale (int): the minimal size of scaling. + max_scale (int): the maximal size of scaling. + crop_size (int): the size of height and width used to crop the + frames. + inverse_uniform_sampling (bool): if True, sample uniformly in + [1 / max_scale, 1 / min_scale] and take a reciprocal to get the + scale. If False, take a uniform sample from [min_scale, + max_scale]. + aspect_ratio (list): Aspect ratio range for resizing. + scale (list): Scale range for resizing. + motion_shift (bool): Whether to apply motion shift for resizing. + Returns: + frames (tensor): spatially sampled frames. + """ + assert spatial_idx in [-1, 0, 1, 2] + if spatial_idx == -1: + if aspect_ratio is None and scale is None: + frames, _ = video_transforms.random_short_side_scale_jitter( + images=frames, + min_size=min_scale, + max_size=max_scale, + inverse_uniform_sampling=inverse_uniform_sampling, + ) + frames, _ = video_transforms.random_crop(frames, crop_size) + else: + transform_func = ( + video_transforms.random_resized_crop_with_shift + if motion_shift else video_transforms.random_resized_crop) + frames = transform_func( + images=frames, + target_height=crop_size, + target_width=crop_size, + scale=scale, + ratio=aspect_ratio, + ) + if random_horizontal_flip: + frames, _ = video_transforms.horizontal_flip(0.5, frames) + else: + # The testing is deterministic and no jitter should be performed. + # min_scale, max_scale, and crop_size are expect to be the same. + assert len({min_scale, max_scale, crop_size}) == 1 + frames, _ = video_transforms.random_short_side_scale_jitter( + frames, min_scale, max_scale) + frames, _ = video_transforms.uniform_crop(frames, crop_size, + spatial_idx) + return frames + + +def tensor_normalize(tensor, mean, std): + """ + Normalize a given tensor by subtracting the mean and dividing the std. + Args: + tensor (tensor): tensor to normalize. + mean (tensor or list): mean value to subtract. + std (tensor or list): std to divide. + """ + if tensor.dtype == torch.uint8: + tensor = tensor.float() + tensor = tensor / 255.0 + if type(mean) == list: + mean = torch.tensor(mean) + if type(std) == list: + std = torch.tensor(std) + tensor = tensor - mean + tensor = tensor / std + return tensor diff --git a/v_cls/functional.py b/v_cls/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..e6bf3ea02af02e88259590c85d74377e8ac7b8a9 --- /dev/null +++ b/v_cls/functional.py @@ -0,0 +1,90 @@ +import numbers + +import cv2 +import numpy as np +import PIL +import torch + + +def _is_tensor_clip(clip): + return torch.is_tensor(clip) and clip.ndimension() == 4 + + +def crop_clip(clip, min_h, min_w, h, w): + if isinstance(clip[0], np.ndarray): + cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] + + elif isinstance(clip[0], PIL.Image.Image): + cropped = [ + img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return cropped + + +def resize_clip(clip, size, interpolation='bilinear'): + if isinstance(clip[0], np.ndarray): + if isinstance(size, numbers.Number): + im_h, im_w, im_c = clip[0].shape + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[0], size[1] + if interpolation == 'bilinear': + np_inter = cv2.INTER_LINEAR + else: + np_inter = cv2.INTER_NEAREST + scaled = [ + cv2.resize(img, size, interpolation=np_inter) for img in clip + ] + elif isinstance(clip[0], PIL.Image.Image): + if isinstance(size, numbers.Number): + im_w, im_h = clip[0].size + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[1], size[0] + if interpolation == 'bilinear': + pil_inter = PIL.Image.BILINEAR + else: + pil_inter = PIL.Image.NEAREST + scaled = [img.resize(size, pil_inter) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return scaled + + +def get_resize_sizes(im_h, im_w, size): + if im_w < im_h: + ow = size + oh = int(size * im_h / im_w) + else: + oh = size + ow = int(size * im_w / im_h) + return oh, ow + + +def normalize(clip, mean, std, inplace=False): + if not _is_tensor_clip(clip): + raise TypeError('tensor is not a torch clip.') + + if not inplace: + clip = clip.clone() + + dtype = clip.dtype + mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) + std = torch.as_tensor(std, dtype=dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + + return clip diff --git a/v_cls/kinetics_400_labels.csv b/v_cls/kinetics_400_labels.csv new file mode 100644 index 0000000000000000000000000000000000000000..31c52a4c6036943b55f9780382a01b4fe6c64533 --- /dev/null +++ b/v_cls/kinetics_400_labels.csv @@ -0,0 +1,401 @@ +id,name +0,abseiling +1,air drumming +2,answering questions +3,applauding +4,applying cream +5,archery +6,arm wrestling +7,arranging flowers +8,assembling computer +9,auctioning +10,baby waking up +11,baking cookies +12,balloon blowing +13,bandaging +14,barbequing +15,bartending +16,beatboxing +17,bee keeping +18,belly dancing +19,bench pressing +20,bending back +21,bending metal +22,biking through snow +23,blasting sand +24,blowing glass +25,blowing leaves +26,blowing nose +27,blowing out candles +28,bobsledding +29,bookbinding +30,bouncing on trampoline +31,bowling +32,braiding hair +33,breading or breadcrumbing +34,breakdancing +35,brush painting +36,brushing hair +37,brushing teeth +38,building cabinet +39,building shed +40,bungee jumping +41,busking +42,canoeing or kayaking +43,capoeira +44,carrying baby +45,cartwheeling +46,carving pumpkin +47,catching fish +48,catching or throwing baseball +49,catching or throwing frisbee +50,catching or throwing softball +51,celebrating +52,changing oil +53,changing wheel +54,checking tires +55,cheerleading +56,chopping wood +57,clapping +58,clay pottery making +59,clean and jerk +60,cleaning floor +61,cleaning gutters +62,cleaning pool +63,cleaning shoes +64,cleaning toilet +65,cleaning windows +66,climbing a rope +67,climbing ladder +68,climbing tree +69,contact juggling +70,cooking chicken +71,cooking egg +72,cooking on campfire +73,cooking sausages +74,counting money +75,country line dancing +76,cracking neck +77,crawling baby +78,crossing river +79,crying +80,curling hair +81,cutting nails +82,cutting pineapple +83,cutting watermelon +84,dancing ballet +85,dancing charleston +86,dancing gangnam style +87,dancing macarena +88,deadlifting +89,decorating the christmas tree +90,digging +91,dining +92,disc golfing +93,diving cliff +94,dodgeball +95,doing aerobics +96,doing laundry +97,doing nails +98,drawing +99,dribbling basketball +100,drinking +101,drinking beer +102,drinking shots +103,driving car +104,driving tractor +105,drop kicking +106,drumming fingers +107,dunking basketball +108,dying hair +109,eating burger +110,eating cake +111,eating carrots +112,eating chips +113,eating doughnuts +114,eating hotdog +115,eating ice cream +116,eating spaghetti +117,eating watermelon +118,egg hunting +119,exercising arm +120,exercising with an exercise ball +121,extinguishing fire +122,faceplanting +123,feeding birds +124,feeding fish +125,feeding goats +126,filling eyebrows +127,finger snapping +128,fixing hair +129,flipping pancake +130,flying kite +131,folding clothes +132,folding napkins +133,folding paper +134,front raises +135,frying vegetables +136,garbage collecting +137,gargling +138,getting a haircut +139,getting a tattoo +140,giving or receiving award +141,golf chipping +142,golf driving +143,golf putting +144,grinding meat +145,grooming dog +146,grooming horse +147,gymnastics tumbling +148,hammer throw +149,headbanging +150,headbutting +151,high jump +152,high kick +153,hitting baseball +154,hockey stop +155,holding snake +156,hopscotch +157,hoverboarding +158,hugging +159,hula hooping +160,hurdling +161,hurling (sport) +162,ice climbing +163,ice fishing +164,ice skating +165,ironing +166,javelin throw +167,jetskiing +168,jogging +169,juggling balls +170,juggling fire +171,juggling soccer ball +172,jumping into pool +173,jumpstyle dancing +174,kicking field goal +175,kicking soccer ball +176,kissing +177,kitesurfing +178,knitting +179,krumping +180,laughing +181,laying bricks +182,long jump +183,lunge +184,making a cake +185,making a sandwich +186,making bed +187,making jewelry +188,making pizza +189,making snowman +190,making sushi +191,making tea +192,marching +193,massaging back +194,massaging feet +195,massaging legs +196,massaging person's head +197,milking cow +198,mopping floor +199,motorcycling +200,moving furniture +201,mowing lawn +202,news anchoring +203,opening bottle +204,opening present +205,paragliding +206,parasailing +207,parkour +208,passing American football (in game) +209,passing American football (not in game) +210,peeling apples +211,peeling potatoes +212,petting animal (not cat) +213,petting cat +214,picking fruit +215,planting trees +216,plastering +217,playing accordion +218,playing badminton +219,playing bagpipes +220,playing basketball +221,playing bass guitar +222,playing cards +223,playing cello +224,playing chess +225,playing clarinet +226,playing controller +227,playing cricket +228,playing cymbals +229,playing didgeridoo +230,playing drums +231,playing flute +232,playing guitar +233,playing harmonica +234,playing harp +235,playing ice hockey +236,playing keyboard +237,playing kickball +238,playing monopoly +239,playing organ +240,playing paintball +241,playing piano +242,playing poker +243,playing recorder +244,playing saxophone +245,playing squash or racquetball +246,playing tennis +247,playing trombone +248,playing trumpet +249,playing ukulele +250,playing violin +251,playing volleyball +252,playing xylophone +253,pole vault +254,presenting weather forecast +255,pull ups +256,pumping fist +257,pumping gas +258,punching bag +259,punching person (boxing) +260,push up +261,pushing car +262,pushing cart +263,pushing wheelchair +264,reading book +265,reading newspaper +266,recording music +267,riding a bike +268,riding camel +269,riding elephant +270,riding mechanical bull +271,riding mountain bike +272,riding mule +273,riding or walking with horse +274,riding scooter +275,riding unicycle +276,ripping paper +277,robot dancing +278,rock climbing +279,rock scissors paper +280,roller skating +281,running on treadmill +282,sailing +283,salsa dancing +284,sanding floor +285,scrambling eggs +286,scuba diving +287,setting table +288,shaking hands +289,shaking head +290,sharpening knives +291,sharpening pencil +292,shaving head +293,shaving legs +294,shearing sheep +295,shining shoes +296,shooting basketball +297,shooting goal (soccer) +298,shot put +299,shoveling snow +300,shredding paper +301,shuffling cards +302,side kick +303,sign language interpreting +304,singing +305,situp +306,skateboarding +307,ski jumping +308,skiing (not slalom or crosscountry) +309,skiing crosscountry +310,skiing slalom +311,skipping rope +312,skydiving +313,slacklining +314,slapping +315,sled dog racing +316,smoking +317,smoking hookah +318,snatch weight lifting +319,sneezing +320,sniffing +321,snorkeling +322,snowboarding +323,snowkiting +324,snowmobiling +325,somersaulting +326,spinning poi +327,spray painting +328,spraying +329,springboard diving +330,squat +331,sticking tongue out +332,stomping grapes +333,stretching arm +334,stretching leg +335,strumming guitar +336,surfing crowd +337,surfing water +338,sweeping floor +339,swimming backstroke +340,swimming breast stroke +341,swimming butterfly stroke +342,swing dancing +343,swinging legs +344,swinging on something +345,sword fighting +346,tai chi +347,taking a shower +348,tango dancing +349,tap dancing +350,tapping guitar +351,tapping pen +352,tasting beer +353,tasting food +354,testifying +355,texting +356,throwing axe +357,throwing ball +358,throwing discus +359,tickling +360,tobogganing +361,tossing coin +362,tossing salad +363,training dog +364,trapezing +365,trimming or shaving beard +366,trimming trees +367,triple jump +368,tying bow tie +369,tying knot (not on a tie) +370,tying tie +371,unboxing +372,unloading truck +373,using computer +374,using remote controller (not gaming) +375,using segway +376,vault +377,waiting in line +378,walking the dog +379,washing dishes +380,washing feet +381,washing hair +382,washing hands +383,water skiing +384,water sliding +385,watering plants +386,waxing back +387,waxing chest +388,waxing eyebrows +389,waxing legs +390,weaving basket +391,welding +392,whistling +393,windsurfing +394,wrapping present +395,wrestling +396,writing +397,yawning +398,yoga +399,zumba diff --git a/v_cls/kinetics_600_labels.csv b/v_cls/kinetics_600_labels.csv new file mode 100644 index 0000000000000000000000000000000000000000..41a5df0547c3cc1f59592f6ce00702043d8ad32e --- /dev/null +++ b/v_cls/kinetics_600_labels.csv @@ -0,0 +1,601 @@ +id,name +0,abseiling +1,acting in play +2,adjusting glasses +3,air drumming +4,alligator wrestling +5,answering questions +6,applauding +7,applying cream +8,archaeological excavation +9,archery +10,arguing +11,arm wrestling +12,arranging flowers +13,assembling bicycle +14,assembling computer +15,attending conference +16,auctioning +17,backflip (human) +18,baking cookies +19,bandaging +20,barbequing +21,bartending +22,base jumping +23,bathing dog +24,battle rope training +25,beatboxing +26,bee keeping +27,belly dancing +28,bench pressing +29,bending back +30,bending metal +31,biking through snow +32,blasting sand +33,blowdrying hair +34,blowing bubble gum +35,blowing glass +36,blowing leaves +37,blowing nose +38,blowing out candles +39,bobsledding +40,bodysurfing +41,bookbinding +42,bottling +43,bouncing on bouncy castle +44,bouncing on trampoline +45,bowling +46,braiding hair +47,breading or breadcrumbing +48,breakdancing +49,breaking boards +50,breathing fire +51,brush painting +52,brushing hair +53,brushing teeth +54,building cabinet +55,building lego +56,building sandcastle +57,building shed +58,bull fighting +59,bulldozing +60,bungee jumping +61,burping +62,busking +63,calculating +64,calligraphy +65,canoeing or kayaking +66,capoeira +67,capsizing +68,card stacking +69,card throwing +70,carrying baby +71,cartwheeling +72,carving ice +73,carving pumpkin +74,casting fishing line +75,catching fish +76,catching or throwing baseball +77,catching or throwing frisbee +78,catching or throwing softball +79,celebrating +80,changing gear in car +81,changing oil +82,changing wheel (not on bike) +83,checking tires +84,cheerleading +85,chewing gum +86,chiseling stone +87,chiseling wood +88,chopping meat +89,chopping vegetables +90,chopping wood +91,clam digging +92,clapping +93,clay pottery making +94,clean and jerk +95,cleaning gutters +96,cleaning pool +97,cleaning shoes +98,cleaning toilet +99,cleaning windows +100,climbing a rope +101,climbing ladder +102,climbing tree +103,coloring in +104,combing hair +105,contact juggling +106,contorting +107,cooking egg +108,cooking on campfire +109,cooking sausages (not on barbeque) +110,cooking scallops +111,cosplaying +112,counting money +113,country line dancing +114,cracking back +115,cracking knuckles +116,cracking neck +117,crawling baby +118,crossing eyes +119,crossing river +120,crying +121,cumbia +122,curling (sport) +123,curling hair +124,cutting apple +125,cutting nails +126,cutting orange +127,cutting pineapple +128,cutting watermelon +129,dancing ballet +130,dancing charleston +131,dancing gangnam style +132,dancing macarena +133,deadlifting +134,decorating the christmas tree +135,delivering mail +136,dining +137,directing traffic +138,disc golfing +139,diving cliff +140,docking boat +141,dodgeball +142,doing aerobics +143,doing jigsaw puzzle +144,doing laundry +145,doing nails +146,drawing +147,dribbling basketball +148,drinking shots +149,driving car +150,driving tractor +151,drooling +152,drop kicking +153,drumming fingers +154,dumpster diving +155,dunking basketball +156,dyeing eyebrows +157,dyeing hair +158,eating burger +159,eating cake +160,eating carrots +161,eating chips +162,eating doughnuts +163,eating hotdog +164,eating ice cream +165,eating spaghetti +166,eating watermelon +167,egg hunting +168,embroidering +169,exercising with an exercise ball +170,extinguishing fire +171,faceplanting +172,falling off bike +173,falling off chair +174,feeding birds +175,feeding fish +176,feeding goats +177,fencing (sport) +178,fidgeting +179,finger snapping +180,fixing bicycle +181,fixing hair +182,flint knapping +183,flipping pancake +184,fly tying +185,flying kite +186,folding clothes +187,folding napkins +188,folding paper +189,front raises +190,frying vegetables +191,geocaching +192,getting a haircut +193,getting a piercing +194,getting a tattoo +195,giving or receiving award +196,gold panning +197,golf chipping +198,golf driving +199,golf putting +200,gospel singing in church +201,grinding meat +202,grooming dog +203,grooming horse +204,gymnastics tumbling +205,hammer throw +206,hand washing clothes +207,head stand +208,headbanging +209,headbutting +210,high jump +211,high kick +212,historical reenactment +213,hitting baseball +214,hockey stop +215,holding snake +216,home roasting coffee +217,hopscotch +218,hoverboarding +219,huddling +220,hugging (not baby) +221,hugging baby +222,hula hooping +223,hurdling +224,hurling (sport) +225,ice climbing +226,ice fishing +227,ice skating +228,ice swimming +229,inflating balloons +230,installing carpet +231,ironing +232,ironing hair +233,javelin throw +234,jaywalking +235,jetskiing +236,jogging +237,juggling balls +238,juggling fire +239,juggling soccer ball +240,jumping bicycle +241,jumping into pool +242,jumping jacks +243,jumpstyle dancing +244,karaoke +245,kicking field goal +246,kicking soccer ball +247,kissing +248,kitesurfing +249,knitting +250,krumping +251,land sailing +252,laughing +253,lawn mower racing +254,laying bricks +255,laying concrete +256,laying stone +257,laying tiles +258,leatherworking +259,licking +260,lifting hat +261,lighting fire +262,lock picking +263,long jump +264,longboarding +265,looking at phone +266,luge +267,lunge +268,making a cake +269,making a sandwich +270,making balloon shapes +271,making bubbles +272,making cheese +273,making horseshoes +274,making jewelry +275,making paper aeroplanes +276,making pizza +277,making snowman +278,making sushi +279,making tea +280,making the bed +281,marching +282,marriage proposal +283,massaging back +284,massaging feet +285,massaging legs +286,massaging neck +287,massaging person's head +288,milking cow +289,moon walking +290,mopping floor +291,mosh pit dancing +292,motorcycling +293,mountain climber (exercise) +294,moving furniture +295,mowing lawn +296,mushroom foraging +297,needle felting +298,news anchoring +299,opening bottle (not wine) +300,opening door +301,opening present +302,opening refrigerator +303,opening wine bottle +304,packing +305,paragliding +306,parasailing +307,parkour +308,passing American football (in game) +309,passing american football (not in game) +310,passing soccer ball +311,peeling apples +312,peeling potatoes +313,person collecting garbage +314,petting animal (not cat) +315,petting cat +316,photobombing +317,photocopying +318,picking fruit +319,pillow fight +320,pinching +321,pirouetting +322,planing wood +323,planting trees +324,plastering +325,playing accordion +326,playing badminton +327,playing bagpipes +328,playing basketball +329,playing bass guitar +330,playing beer pong +331,playing blackjack +332,playing cello +333,playing chess +334,playing clarinet +335,playing controller +336,playing cricket +337,playing cymbals +338,playing darts +339,playing didgeridoo +340,playing dominoes +341,playing drums +342,playing field hockey +343,playing flute +344,playing gong +345,playing guitar +346,playing hand clapping games +347,playing harmonica +348,playing harp +349,playing ice hockey +350,playing keyboard +351,playing kickball +352,playing laser tag +353,playing lute +354,playing maracas +355,playing marbles +356,playing monopoly +357,playing netball +358,playing ocarina +359,playing organ +360,playing paintball +361,playing pan pipes +362,playing piano +363,playing pinball +364,playing ping pong +365,playing poker +366,playing polo +367,playing recorder +368,playing rubiks cube +369,playing saxophone +370,playing scrabble +371,playing squash or racquetball +372,playing tennis +373,playing trombone +374,playing trumpet +375,playing ukulele +376,playing violin +377,playing volleyball +378,playing with trains +379,playing xylophone +380,poking bellybutton +381,pole vault +382,polishing metal +383,popping balloons +384,pouring beer +385,preparing salad +386,presenting weather forecast +387,pull ups +388,pumping fist +389,pumping gas +390,punching bag +391,punching person (boxing) +392,push up +393,pushing car +394,pushing cart +395,pushing wheelbarrow +396,pushing wheelchair +397,putting in contact lenses +398,putting on eyeliner +399,putting on foundation +400,putting on lipstick +401,putting on mascara +402,putting on sari +403,putting on shoes +404,raising eyebrows +405,reading book +406,reading newspaper +407,recording music +408,repairing puncture +409,riding a bike +410,riding camel +411,riding elephant +412,riding mechanical bull +413,riding mule +414,riding or walking with horse +415,riding scooter +416,riding snow blower +417,riding unicycle +418,ripping paper +419,roasting marshmallows +420,roasting pig +421,robot dancing +422,rock climbing +423,rock scissors paper +424,roller skating +425,rolling pastry +426,rope pushdown +427,running on treadmill +428,sailing +429,salsa dancing +430,sanding floor +431,sausage making +432,sawing wood +433,scrambling eggs +434,scrapbooking +435,scrubbing face +436,scuba diving +437,separating eggs +438,setting table +439,sewing +440,shaking hands +441,shaking head +442,shaping bread dough +443,sharpening knives +444,sharpening pencil +445,shaving head +446,shaving legs +447,shearing sheep +448,shining flashlight +449,shining shoes +450,shooting basketball +451,shooting goal (soccer) +452,shopping +453,shot put +454,shoveling snow +455,shucking oysters +456,shuffling cards +457,shuffling feet +458,side kick +459,sign language interpreting +460,singing +461,sipping cup +462,situp +463,skateboarding +464,ski jumping +465,skiing crosscountry +466,skiing mono +467,skiing slalom +468,skipping rope +469,skipping stone +470,skydiving +471,slacklining +472,slapping +473,sled dog racing +474,sleeping +475,smashing +476,smelling feet +477,smoking +478,smoking hookah +479,smoking pipe +480,snatch weight lifting +481,sneezing +482,snorkeling +483,snowboarding +484,snowkiting +485,snowmobiling +486,somersaulting +487,spelunking +488,spinning poi +489,spray painting +490,springboard diving +491,square dancing +492,squat +493,standing on hands +494,staring +495,steer roping +496,sticking tongue out +497,stomping grapes +498,stretching arm +499,stretching leg +500,sucking lolly +501,surfing crowd +502,surfing water +503,sweeping floor +504,swimming backstroke +505,swimming breast stroke +506,swimming butterfly stroke +507,swimming front crawl +508,swing dancing +509,swinging baseball bat +510,swinging on something +511,sword fighting +512,sword swallowing +513,tackling +514,tagging graffiti +515,tai chi +516,talking on cell phone +517,tango dancing +518,tap dancing +519,tapping guitar +520,tapping pen +521,tasting beer +522,tasting food +523,tasting wine +524,testifying +525,texting +526,threading needle +527,throwing axe +528,throwing ball (not baseball or American football) +529,throwing discus +530,throwing knife +531,throwing snowballs +532,throwing tantrum +533,throwing water balloon +534,tickling +535,tie dying +536,tightrope walking +537,tiptoeing +538,tobogganing +539,tossing coin +540,training dog +541,trapezing +542,trimming or shaving beard +543,trimming shrubs +544,trimming trees +545,triple jump +546,twiddling fingers +547,tying bow tie +548,tying knot (not on a tie) +549,tying necktie +550,tying shoe laces +551,unboxing +552,unloading truck +553,using a microscope +554,using a paint roller +555,using a power drill +556,using a sledge hammer +557,using a wrench +558,using atm +559,using bagging machine +560,using circular saw +561,using inhaler +562,using puppets +563,using remote controller (not gaming) +564,using segway +565,vacuuming floor +566,visiting the zoo +567,wading through mud +568,wading through water +569,waiting in line +570,waking up +571,walking the dog +572,walking through snow +573,washing dishes +574,washing feet +575,washing hair +576,washing hands +577,watching tv +578,water skiing +579,water sliding +580,watering plants +581,waving hand +582,waxing back +583,waxing chest +584,waxing eyebrows +585,waxing legs +586,weaving basket +587,weaving fabric +588,welding +589,whistling +590,windsurfing +591,winking +592,wood burning (art) +593,wrapping present +594,wrestling +595,writing +596,yarn spinning +597,yawning +598,yoga +599,zumba diff --git a/v_cls/loader.py b/v_cls/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..cb8a31ccc0ac0928f1717f07f1425b5496cdcc32 --- /dev/null +++ b/v_cls/loader.py @@ -0,0 +1,54 @@ +import io + +import cv2 +import decord +import numpy as np +from decord import VideoReader, cpu + +try: + from petrel_client.client import Client + petrel_backend_imported = True +except (ImportError, ModuleNotFoundError): + petrel_backend_imported = False + + +def get_video_loader(use_petrel_backend: bool = True, + enable_mc: bool = True, + conf_path: str = None): + if petrel_backend_imported and use_petrel_backend: + _client = Client(conf_path=conf_path, enable_mc=enable_mc) + else: + _client = None + + def _loader(video_path): + if _client is not None and 's3:' in video_path: + video_path = io.BytesIO(_client.get(video_path)) + + decord.bridge.set_bridge('native') + vr = VideoReader(video_path, num_threads=1, ctx=cpu(0)) + return vr + + return _loader + + +def get_image_loader(use_petrel_backend: bool = True, + enable_mc: bool = True, + conf_path: str = None): + if petrel_backend_imported and use_petrel_backend: + _client = Client(conf_path=conf_path, enable_mc=enable_mc) + else: + _client = None + + def _loader(frame_path): + if _client is not None and 's3:' in frame_path: + img_bytes = _client.get(frame_path) + else: + with open(frame_path, 'rb') as f: + img_bytes = f.read() + + img_np = np.frombuffer(img_bytes, np.uint8) + img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) + return img + + return _loader diff --git a/v_cls/masking_generator.py b/v_cls/masking_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c858aaf61ab08643c5681f3e29b3baff0461e21e --- /dev/null +++ b/v_cls/masking_generator.py @@ -0,0 +1,113 @@ +# -------------------------------------------------------- +# Based on BEiT, timm, DINO and DeiT code bases +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/facebookresearch/deit +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import numpy as np + + +class Cell(): + + def __init__(self, num_masks, num_patches): + self.num_masks = num_masks + self.num_patches = num_patches + self.size = num_masks + num_patches + self.queue = np.hstack([np.ones(num_masks), np.zeros(num_patches)]) + self.queue_ptr = 0 + + def set_ptr(self, pos=-1): + self.queue_ptr = np.random.randint(self.size) if pos < 0 else pos + + def get_cell(self): + cell_idx = (np.arange(self.size) + self.queue_ptr) % self.size + return self.queue[cell_idx] + + def run_cell(self): + self.queue_ptr += 1 + + +class RandomMaskingGenerator: + + def __init__(self, input_size, mask_ratio): + if not isinstance(input_size, tuple): + input_size = (input_size, ) * 3 + + self.frames, self.height, self.width = input_size + + self.num_patches = self.frames * self.height * self.width # 8x14x14 + self.num_mask = int(mask_ratio * self.num_patches) + + def __repr__(self): + repr_str = "Mask: total patches {}, mask patches {}".format( + self.num_patches, self.num_mask) + return repr_str + + def __call__(self): + mask = np.hstack([ + np.zeros(self.num_patches - self.num_mask), + np.ones(self.num_mask), + ]) + np.random.shuffle(mask) + return mask # [196*8] + + +class TubeMaskingGenerator: + + def __init__(self, input_size, mask_ratio): + self.frames, self.height, self.width = input_size + self.num_patches_per_frame = self.height * self.width # 14x14 + self.total_patches = self.frames * self.num_patches_per_frame + self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) + self.total_masks = self.frames * self.num_masks_per_frame + + def __repr__(self): + repr_str = "Tube Masking: total patches {}, mask patches {}".format( + self.total_patches, self.total_masks) + return repr_str + + def __call__(self): + mask_per_frame = np.hstack([ + np.zeros(self.num_patches_per_frame - self.num_masks_per_frame), + np.ones(self.num_masks_per_frame), + ]) + np.random.shuffle(mask_per_frame) + mask = np.tile(mask_per_frame, (self.frames, 1)) + return mask # [196*8] + + +class RunningCellMaskingGenerator: + + def __init__(self, input_size, mask_ratio=0.5): + self.frames, self.height, self.width = input_size + self.mask_ratio = mask_ratio + + num_masks_per_cell = int(4 * self.mask_ratio) + assert 0 < num_masks_per_cell < 4 + num_patches_per_cell = 4 - num_masks_per_cell + + self.cell = Cell(num_masks_per_cell, num_patches_per_cell) + self.cell_size = self.cell.size + + mask_list = [] + for ptr_pos in range(self.cell_size): + self.cell.set_ptr(ptr_pos) + mask = [] + for _ in range(self.frames): + self.cell.run_cell() + mask_unit = self.cell.get_cell().reshape(2, 2) + mask_map = np.tile(mask_unit, + [self.height // 2, self.width // 2]) + mask.append(mask_map.flatten()) + mask = np.stack(mask, axis=0) + mask_list.append(mask) + self.all_mask_maps = np.stack(mask_list, axis=0) + + def __repr__(self): + repr_str = f"Running Cell Masking with mask ratio {self.mask_ratio}" + return repr_str + + def __call__(self): + mask = self.all_mask_maps[np.random.randint(self.cell_size)] + return np.copy(mask) diff --git a/v_cls/pretrain_datasets.py b/v_cls/pretrain_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..22ea76f2f5bd19b0313fabbcd40f80bd10bbe65a --- /dev/null +++ b/v_cls/pretrain_datasets.py @@ -0,0 +1,483 @@ +import json +import os +import random + +import numpy as np +import torch +from PIL import Image + +from .loader import get_image_loader, get_video_loader + + + +class HybridVideoMAE(torch.utils.data.Dataset): + """Load your own videomae pretraining dataset. + Parameters + ---------- + root : str, required. + Path to the root folder storing the dataset. + setting : str, required. + A text file describing the dataset, each line per video sample. + There are four items in each line: + (1) video path; (2) start_idx, (3) total frames and (4) video label. + for pre-train video data + total frames < 0, start_idx and video label meaningless + for pre-train rawframe data + video label meaningless + train : bool, default True. + Whether to load the training or validation set. + test_mode : bool, default False. + Whether to perform evaluation on the test set. + Usually there is three-crop or ten-crop evaluation strategy involved. + name_pattern : str, default 'img_{:05}.jpg'. + The naming pattern of the decoded video frames. + For example, img_00012.jpg. + video_ext : str, default 'mp4'. + If video_loader is set to True, please specify the video format accordinly. + is_color : bool, default True. + Whether the loaded image is color or grayscale. + modality : str, default 'rgb'. + Input modalities, we support only rgb video frames for now. + Will add support for rgb difference image and optical flow image later. + num_segments : int, default 1. + Number of segments to evenly divide the video into clips. + A useful technique to obtain global video-level information. + Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016. + num_crop : int, default 1. + Number of crops for each image. default is 1. + Common choices are three crops and ten crops during evaluation. + new_length : int, default 1. + The length of input video clip. Default is a single image, but it can be multiple video frames. + For example, new_length=16 means we will extract a video clip of consecutive 16 frames. + new_step : int, default 1. + Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames. + new_step=2 means we will extract a video clip of every other frame. + transform : function, default None. + A function that takes data and label and transforms them. + temporal_jitter : bool, default False. + Whether to temporally jitter if new_step > 1. + lazy_init : bool, default False. + If set to True, build a dataset instance without loading any dataset. + num_sample : int, default 1. + Number of sampled views for Repeated Augmentation. + """ + + def __init__(self, + root, + setting, + train=True, + test_mode=False, + name_pattern='img_{:05}.jpg', + video_ext='mp4', + is_color=True, + modality='rgb', + num_segments=1, + num_crop=1, + new_length=1, + new_step=1, + transform=None, + temporal_jitter=False, + lazy_init=False, + num_sample=1): + + super(HybridVideoMAE, self).__init__() + self.root = root + self.setting = setting + self.train = train + self.test_mode = test_mode + self.is_color = is_color + self.modality = modality + self.num_segments = num_segments + self.num_crop = num_crop + self.new_length = new_length + self.new_step = new_step + self.skip_length = self.new_length * self.new_step + self.temporal_jitter = temporal_jitter + self.name_pattern = name_pattern + self.video_ext = video_ext + self.transform = transform + self.lazy_init = lazy_init + self.num_sample = num_sample + + # NOTE: + # for hybrid train + # different frame naming formats are used for different datasets + # should MODIFY the fname_tmpl to your own situation + self.ava_fname_tmpl = 'image_{:06}.jpg' + self.ssv2_fname_tmpl = 'img_{:05}.jpg' + + # NOTE: + # we set sampling_rate = 2 for ssv2 + # thus being consistent with the fine-tuning stage + # Note that the ssv2 we use is decoded to frames at 12 fps; + # if decoded at 24 fps, the sample interval should be 4. + self.ssv2_skip_length = self.new_length * 2 + self.orig_skip_length = self.skip_length + + self.video_loader = get_video_loader() + self.image_loader = get_image_loader() + + if not self.lazy_init: + self.clips = self._make_dataset(root, setting) + if len(self.clips) == 0: + raise ( + RuntimeError("Found 0 video clips in subfolders of: " + + root + "\n" + "Check your data directory (opt.data-dir).")) + + def __getitem__(self, index): + try: + video_name, start_idx, total_frame = self.clips[index] + self.skip_length = self.orig_skip_length + + if total_frame < 0: + decord_vr = self.video_loader(video_name) + duration = len(decord_vr) + + segment_indices, skip_offsets = self._sample_train_indices( + duration) + frame_id_list = self.get_frame_id_list(duration, + segment_indices, + skip_offsets) + video_data = decord_vr.get_batch(frame_id_list).asnumpy() + images = [ + Image.fromarray(video_data[vid, :, :, :]).convert('RGB') + for vid, _ in enumerate(frame_id_list) + ] + + else: + # ssv2 & ava & other rawframe dataset + if 'SomethingV2' in video_name: + self.skip_length = self.ssv2_skip_length + fname_tmpl = self.ssv2_fname_tmpl + elif 'AVA2.2' in video_name: + fname_tmpl = self.ava_fname_tmpl + else: + fname_tmpl = self.name_pattern + + segment_indices, skip_offsets = self._sample_train_indices( + total_frame) + frame_id_list = self.get_frame_id_list(total_frame, + segment_indices, + skip_offsets) + + images = [] + for idx in frame_id_list: + frame_fname = os.path.join( + video_name, fname_tmpl.format(idx + start_idx)) + img = self.image_loader(frame_fname) + img = Image.fromarray(img) + images.append(img) + + except Exception as e: + print("Failed to load video from {} with error {}".format( + video_name, e)) + index = random.randint(0, len(self.clips) - 1) + return self.__getitem__(index) + + if self.num_sample > 1: + process_data_list = [] + encoder_mask_list = [] + decoder_mask_list = [] + for _ in range(self.num_sample): + process_data, encoder_mask, decoder_mask = self.transform( + (images, None)) + process_data = process_data.view( + (self.new_length, 3) + process_data.size()[-2:]).transpose( + 0, 1) + process_data_list.append(process_data) + encoder_mask_list.append(encoder_mask) + decoder_mask_list.append(decoder_mask) + return process_data_list, encoder_mask_list, decoder_mask_list + else: + process_data, encoder_mask, decoder_mask = self.transform( + (images, None)) + # T*C,H,W -> T,C,H,W -> C,T,H,W + process_data = process_data.view( + (self.new_length, 3) + process_data.size()[-2:]).transpose( + 0, 1) + return process_data, encoder_mask, decoder_mask + + def __len__(self): + return len(self.clips) + + def _make_dataset(self, root, setting): + if not os.path.exists(setting): + raise (RuntimeError( + "Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " + % (setting))) + clips = [] + with open(setting) as split_f: + data = split_f.readlines() + for line in data: + line_info = line.split(' ') + # line format: video_path, video_duration, video_label + if len(line_info) < 2: + raise (RuntimeError( + 'Video input format is not correct, missing one or more element. %s' + % line)) + clip_path = os.path.join(root, line_info[0]) + start_idx = int(line_info[1]) + total_frame = int(line_info[2]) + item = (clip_path, start_idx, total_frame) + clips.append(item) + return clips + + def _sample_train_indices(self, num_frames): + average_duration = (num_frames - self.skip_length + + 1) // self.num_segments + if average_duration > 0: + offsets = np.multiply( + list(range(self.num_segments)), average_duration) + offsets = offsets + np.random.randint( + average_duration, size=self.num_segments) + elif num_frames > max(self.num_segments, self.skip_length): + offsets = np.sort( + np.random.randint( + num_frames - self.skip_length + 1, size=self.num_segments)) + else: + offsets = np.zeros((self.num_segments, )) + + if self.temporal_jitter: + skip_offsets = np.random.randint( + self.new_step, size=self.skip_length // self.new_step) + else: + skip_offsets = np.zeros( + self.skip_length // self.new_step, dtype=int) + return offsets + 1, skip_offsets + + def get_frame_id_list(self, duration, indices, skip_offsets): + frame_id_list = [] + for seg_ind in indices: + offset = int(seg_ind) + for i, _ in enumerate(range(0, self.skip_length, self.new_step)): + if offset + skip_offsets[i] <= duration: + frame_id = offset + skip_offsets[i] - 1 + else: + frame_id = offset - 1 + frame_id_list.append(frame_id) + if offset + self.new_step < duration: + offset += self.new_step + return frame_id_list + +class VideoMAE(torch.utils.data.Dataset): + """Load your own videomae pretraining dataset. + Parameters + ---------- + root : str, required. + Path to the root folder storing the dataset. + setting : str, required. + A text file describing the dataset, each line per video sample. + There are four items in each line: + (1) video path; (2) start_idx, (3) total frames and (4) video label. + for pre-train video data + total frames < 0, start_idx and video label meaningless + for pre-train rawframe data + video label meaningless + train : bool, default True. + Whether to load the training or validation set. + test_mode : bool, default False. + Whether to perform evaluation on the test set. + Usually there is three-crop or ten-crop evaluation strategy involved. + name_pattern : str, default 'img_{:05}.jpg'. + The naming pattern of the decoded video frames. + For example, img_00012.jpg. + video_ext : str, default 'mp4'. + If video_loader is set to True, please specify the video format accordinly. + is_color : bool, default True. + Whether the loaded image is color or grayscale. + modality : str, default 'rgb'. + Input modalities, we support only rgb video frames for now. + Will add support for rgb difference image and optical flow image later. + num_segments : int, default 1. + Number of segments to evenly divide the video into clips. + A useful technique to obtain global video-level information. + Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016. + num_crop : int, default 1. + Number of crops for each image. default is 1. + Common choices are three crops and ten crops during evaluation. + new_length : int, default 1. + The length of input video clip. Default is a single image, but it can be multiple video frames. + For example, new_length=16 means we will extract a video clip of consecutive 16 frames. + new_step : int, default 1. + Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames. + new_step=2 means we will extract a video clip of every other frame. + transform : function, default None. + A function that takes data and label and transforms them. + temporal_jitter : bool, default False. + Whether to temporally jitter if new_step > 1. + lazy_init : bool, default False. + If set to True, build a dataset instance without loading any dataset. + num_sample : int, default 1. + Number of sampled views for Repeated Augmentation. + """ + + def __init__(self, + root, + setting, + train=True, + test_mode=False, + name_pattern='img_{:05}.jpg', + video_ext='mp4', + is_color=True, + modality='rgb', + num_segments=1, + num_crop=1, + new_length=1, + new_step=1, + transform=None, + temporal_jitter=False, + lazy_init=False, + num_sample=1): + + super(VideoMAE, self).__init__() + self.root = root + self.setting = setting + self.train = train + self.test_mode = test_mode + self.is_color = is_color + self.modality = modality + self.num_segments = num_segments + self.num_crop = num_crop + self.new_length = new_length + self.new_step = new_step + self.skip_length = self.new_length * self.new_step + self.temporal_jitter = temporal_jitter + self.name_pattern = name_pattern + self.video_ext = video_ext + self.transform = transform + self.lazy_init = lazy_init + self.num_sample = num_sample + + self.video_loader = get_video_loader() + self.image_loader = get_image_loader() + + if not self.lazy_init: + # self.anno_path = '/apdcephfs_cq3/share_1311970/A_Youtube/coco_vat_vat0_11_all_id_rootfolder_clsidx_spacy.json' + # self.video_root = '/apdcephfs_cq3/share_1311970/A_Youtube/coco_vat_vat0_11_all_id_rootfolder_clsidx_spacy' + # with open(self.anno_path, 'r') as f: + # anno = eval(json.load(f)) + # keys = list(anno.keys()) + # self.clips = [(os.path.join(self.video_root, key + '.mp4'), anno[key]['idx_list']) for key in + # keys] + + self.anno_path = '/apdcephfs_cq3/share_1311970/A_Youtube/category_idlist_dict.json' + self.video_root = '/apdcephfs_cq3/share_1311970/A_Youtube' + with open(self.anno_path, 'r') as f: + content = json.load(f) + clips = content['Sports'] + self.clips = [[os.path.join(self.video_root, v, k + '.mp4'), -1] for k, v in clips.items()] + + + + + if len(self.clips) == 0: + raise (RuntimeError("Found 0 video clips in subfolders of: " + root + "\n")) + + def __getitem__(self, index): + try: + video_name, start_idx = self.clips[index] + decord_vr = self.video_loader(video_name) + duration = len(decord_vr) + + segment_indices, skip_offsets = self._sample_train_indices( + duration) + frame_id_list = self.get_frame_id_list(duration, + segment_indices, + skip_offsets) + video_data = decord_vr.get_batch(frame_id_list).asnumpy() + images = [ + Image.fromarray(video_data[vid, :, :, :]).convert('RGB') + for vid, _ in enumerate(frame_id_list) + ] + + except Exception as e: + print("Failed to load video from {} with error {}".format( + video_name, e)) + index = random.randint(0, len(self.clips) - 1) + return self.__getitem__(index) + + if self.num_sample > 1: + process_data_list = [] + encoder_mask_list = [] + decoder_mask_list = [] + for _ in range(self.num_sample): + process_data, encoder_mask, decoder_mask = self.transform( + (images, None)) + process_data = process_data.view( + (self.new_length, 3) + process_data.size()[-2:]).transpose( + 0, 1) + process_data_list.append(process_data) + encoder_mask_list.append(encoder_mask) + decoder_mask_list.append(decoder_mask) + return process_data_list, encoder_mask_list, decoder_mask_list + else: + process_data, encoder_mask, decoder_mask = self.transform( + (images, None)) + # T*C,H,W -> T,C,H,W -> C,T,H,W + process_data = process_data.view( + (self.new_length, 3) + process_data.size()[-2:]).transpose( + 0, 1) + return process_data, encoder_mask, decoder_mask + + def __len__(self): + return len(self.clips) + + def _make_dataset(self, root, setting): + if not os.path.exists(setting): + raise (RuntimeError( + "Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " + % (setting))) + clips = [] + with open(setting) as split_f: + data = split_f.readlines() + for line in data: + line_info = line.split(' ') + # line format: video_path, start_idx, total_frames + if len(line_info) < 3: + raise (RuntimeError( + 'Video input format is not correct, missing one or more element. %s' + % line)) + clip_path = os.path.join(root, line_info[0]) + start_idx = int(line_info[1]) + total_frame = int(line_info[2]) + item = (clip_path, start_idx, total_frame) + clips.append(item) + return clips + + def _sample_train_indices(self, num_frames): + average_duration = (num_frames - self.skip_length + + 1) // self.num_segments + if average_duration > 0: + offsets = np.multiply( + list(range(self.num_segments)), average_duration) + offsets = offsets + np.random.randint( + average_duration, size=self.num_segments) + elif num_frames > max(self.num_segments, self.skip_length): + offsets = np.sort( + np.random.randint( + num_frames - self.skip_length + 1, size=self.num_segments)) + else: + offsets = np.zeros((self.num_segments, )) + + if self.temporal_jitter: + skip_offsets = np.random.randint( + self.new_step, size=self.skip_length // self.new_step) + else: + skip_offsets = np.zeros( + self.skip_length // self.new_step, dtype=int) + return offsets + 1, skip_offsets + + def get_frame_id_list(self, duration, indices, skip_offsets): + frame_id_list = [] + for seg_ind in indices: + offset = int(seg_ind) + for i, _ in enumerate(range(0, self.skip_length, self.new_step)): + if offset + skip_offsets[i] <= duration: + frame_id = offset + skip_offsets[i] - 1 + else: + frame_id = offset - 1 + frame_id_list.append(frame_id) + if offset + self.new_step < duration: + offset += self.new_step + return frame_id_list diff --git a/v_cls/rand_augment.py b/v_cls/rand_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..711701229a1828a2f2a22fba0ed128e764daa60c --- /dev/null +++ b/v_cls/rand_augment.py @@ -0,0 +1,521 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py +pulished under an Apache License 2.0. + +COMMENT FROM ORIGINAL: +AutoAugment, RandAugment, and AugMix for PyTorch +This code implements the searched ImageNet policies with various tweaks and +improvements and does not include any of the search code. AA and RA +Implementation adapted from: + https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py +AugMix adapted from: + https://github.com/google-research/augmix +Papers: + AutoAugment: Learning Augmentation Policies from Data + https://arxiv.org/abs/1805.09501 + Learning Data Augmentation Strategies for Object Detection + https://arxiv.org/abs/1906.11172 + RandAugment: Practical automated data augmentation... + https://arxiv.org/abs/1909.13719 + AugMix: A Simple Data Processing Method to Improve Robustness and + Uncertainty https://arxiv.org/abs/1912.02781 + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import math +import random +import re + +import numpy as np +import PIL +from PIL import Image, ImageEnhance, ImageOps + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10.0 + +_HPARAMS_DEFAULT = { + "translate_const": 250, + "img_mean": _FILL, +} + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop("resample", Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if "fillcolor" in kwargs and _PIL_VER < (5, 0): + kwargs.pop("fillcolor") + kwargs["resample"] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), + **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), + **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), + **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), + **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), + **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), + **kwargs) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], + -rotn_center[1] - post_trans[1], + matrix, + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs["resample"]) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def _rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30.0 + level = _randomly_negate(level) + return (level, ) + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return ((level / _MAX_LEVEL) * 1.8 + 0.1, ) + + +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * 0.9 + level = 1.0 + _randomly_negate(level) + return (level, ) + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level, ) + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams["translate_const"] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level, ) + + +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get("translate_pct", 0.45) + level = (level / _MAX_LEVEL) * translate_pct + level = _randomly_negate(level) + return (level, ) + + +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4), ) + + +def _posterize_increasing_level_to_arg(level, hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return (4 - _posterize_level_to_arg(level, hparams)[0], ) + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4) + 4, ) + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 256), ) + + +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation increases with level + return (256 - _solarize_level_to_arg(level, _hparams)[0], ) + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return (int((level / _MAX_LEVEL) * 110), ) + + +LEVEL_TO_ARG = { + "AutoContrast": None, + "Equalize": None, + "Invert": None, + "Rotate": _rotate_level_to_arg, + # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + "Posterize": _posterize_level_to_arg, + "PosterizeIncreasing": _posterize_increasing_level_to_arg, + "PosterizeOriginal": _posterize_original_level_to_arg, + "Solarize": _solarize_level_to_arg, + "SolarizeIncreasing": _solarize_increasing_level_to_arg, + "SolarizeAdd": _solarize_add_level_to_arg, + "Color": _enhance_level_to_arg, + "ColorIncreasing": _enhance_increasing_level_to_arg, + "Contrast": _enhance_level_to_arg, + "ContrastIncreasing": _enhance_increasing_level_to_arg, + "Brightness": _enhance_level_to_arg, + "BrightnessIncreasing": _enhance_increasing_level_to_arg, + "Sharpness": _enhance_level_to_arg, + "SharpnessIncreasing": _enhance_increasing_level_to_arg, + "ShearX": _shear_level_to_arg, + "ShearY": _shear_level_to_arg, + "TranslateX": _translate_abs_level_to_arg, + "TranslateY": _translate_abs_level_to_arg, + "TranslateXRel": _translate_rel_level_to_arg, + "TranslateYRel": _translate_rel_level_to_arg, +} + +NAME_TO_OP = { + "AutoContrast": auto_contrast, + "Equalize": equalize, + "Invert": invert, + "Rotate": rotate, + "Posterize": posterize, + "PosterizeIncreasing": posterize, + "PosterizeOriginal": posterize, + "Solarize": solarize, + "SolarizeIncreasing": solarize, + "SolarizeAdd": solarize_add, + "Color": color, + "ColorIncreasing": color, + "Contrast": contrast, + "ContrastIncreasing": contrast, + "Brightness": brightness, + "BrightnessIncreasing": brightness, + "Sharpness": sharpness, + "SharpnessIncreasing": sharpness, + "ShearX": shear_x, + "ShearY": shear_y, + "TranslateX": translate_x_abs, + "TranslateY": translate_y_abs, + "TranslateXRel": translate_x_rel, + "TranslateYRel": translate_y_rel, +} + + +class AugmentOp: + """ + Apply for video. + """ + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = { + "fillcolor": + hparams["img_mean"] if "img_mean" in hparams else _FILL, + "resample": + hparams["interpolation"] + if "interpolation" in hparams else _RANDOM_INTERPOLATION, + } + + # If magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + self.magnitude_std = self.hparams.get("magnitude_std", 0) + + def __call__(self, img_list): + if self.prob < 1.0 and random.random() > self.prob: + return img_list + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = ( + self.level_fn(magnitude, self.hparams) + if self.level_fn is not None else ()) + + if isinstance(img_list, list): + return [ + self.aug_fn(img, *level_args, **self.kwargs) + for img in img_list + ] + else: + return self.aug_fn(img_list, *level_args, **self.kwargs) + + +_RAND_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "Posterize", + "Solarize", + "SolarizeAdd", + "Color", + "Contrast", + "Brightness", + "Sharpness", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + +_RAND_INCREASING_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "PosterizeIncreasing", + "SolarizeIncreasing", + "SolarizeAdd", + "ColorIncreasing", + "ContrastIncreasing", + "BrightnessIncreasing", + "SharpnessIncreasing", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. +_RAND_CHOICE_WEIGHTS_0 = { + "Rotate": 0.3, + "ShearX": 0.2, + "ShearY": 0.2, + "TranslateXRel": 0.1, + "TranslateYRel": 0.1, + "Color": 0.025, + "Sharpness": 0.025, + "AutoContrast": 0.025, + "Solarize": 0.005, + "SolarizeAdd": 0.005, + "Contrast": 0.005, + "Brightness": 0.005, + "Equalize": 0.005, + "Posterize": 0, + "Invert": 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [ + AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) + for name in transforms + ] + + +class RandAugment: + + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, + self.num_layers, + replace=self.choice_weights is None, + p=self.choice_weights, + ) + for op in ops: + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams): + """ + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + + Create a RandAugment transform + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) + 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 + :param hparams: Other hparams (kwargs) for the RandAugmentation scheme + :return: A PyTorch compatible Transform + """ + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS + config = config_str.split("-") + assert config[0] == "rand" + config = config[1:] + for c in config: + cs = re.split(r"(\d.*)", c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == "mstd": + # noise param injected via hparams for now + hparams.setdefault("magnitude_std", float(val)) + elif key == "inc": + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS + elif key == "m": + magnitude = int(val) + elif key == "n": + num_layers = int(val) + elif key == "w": + weight_idx = int(val) + else: + assert NotImplementedError + ra_ops = rand_augment_ops( + magnitude=magnitude, hparams=hparams, transforms=transforms) + choice_weights = (None if weight_idx is None else + _select_rand_weights(weight_idx)) + return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git a/v_cls/random_erasing.py b/v_cls/random_erasing.py new file mode 100644 index 0000000000000000000000000000000000000000..73c10742a51f1f38c1f665283747f2629c3fcb00 --- /dev/null +++ b/v_cls/random_erasing.py @@ -0,0 +1,175 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py +pulished under an Apache License 2.0. + +COMMENT FROM ORIGINAL: +Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 +Copyright Zhun Zhong & Liang Zheng +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +import random + +import torch + + +def _get_pixels(per_pixel, + rand_color, + patch_size, + dtype=torch.float32, + device="cuda"): + # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() + # paths, flip the order so normal is run on CPU if this becomes a problem + # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 + if per_pixel: + return torch.empty(patch_size, dtype=dtype, device=device).normal_() + elif rand_color: + return torch.empty((patch_size[0], 1, 1), dtype=dtype, + device=device).normal_() + else: + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) + + +class RandomErasing: + """Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + This variant of RandomErasing is intended to be applied to either a batch + or single image tensor after it has been normalized by dataset mean and std. + Args: + probability: Probability that the Random Erasing operation will be performed. + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. + min_aspect: Minimum aspect ratio of erased area. + mode: pixel color mode, one of 'const', 'rand', or 'pixel' + 'const' - erase block is constant color of 0 for all channels + 'rand' - erase block is same per-channel random (normal) color + 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value. + """ + + def __init__( + self, + probability=0.5, + min_area=0.02, + max_area=1 / 3, + min_aspect=0.3, + max_aspect=None, + mode="const", + min_count=1, + max_count=None, + num_splits=0, + device="cuda", + cube=True, + ): + self.probability = probability + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.min_count = min_count + self.max_count = max_count or min_count + self.num_splits = num_splits + mode = mode.lower() + self.rand_color = False + self.per_pixel = False + self.cube = cube + if mode == "rand": + self.rand_color = True # per block random normal + elif mode == "pixel": + self.per_pixel = True # per pixel random normal + else: + assert not mode or mode == "const" + self.device = device + + def _erase(self, img, chan, img_h, img_w, dtype): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count if self.min_count == self.max_count else + random.randint(self.min_count, self.max_count)) + for _ in range(count): + for _ in range(10): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / + count) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def _erase_cube( + self, + img, + batch_start, + batch_size, + chan, + img_h, + img_w, + dtype, + ): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count if self.min_count == self.max_count else + random.randint(self.min_count, self.max_count)) + for _ in range(count): + for _ in range(100): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / + count) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + for i in range(batch_start, batch_size): + img_instance = img[i] + img_instance[:, top:top + h, + left:left + w] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def __call__(self, input): + if len(input.size()) == 3: + self._erase(input, *input.size(), input.dtype) + else: + batch_size, chan, img_h, img_w = input.size() + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = ( + batch_size // self.num_splits if self.num_splits > 1 else 0) + if self.cube: + self._erase_cube( + input, + batch_start, + batch_size, + chan, + img_h, + img_w, + input.dtype, + ) + else: + for i in range(batch_start, batch_size): + self._erase(input[i], chan, img_h, img_w, input.dtype) + return input diff --git a/v_cls/transforms.py b/v_cls/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..530b2544d03f9e90a3e1b2f75154037678b52bb0 --- /dev/null +++ b/v_cls/transforms.py @@ -0,0 +1,586 @@ +# -------------------------------------------------------- +# Based on BEiT, timm, DINO and DeiT code bases +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/facebookresearch/deit +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import math +import numbers +import random +import warnings + +import numpy as np +import torch +import torchvision +import torchvision.transforms.functional as F +from PIL import Image, ImageOps + + +class ToNumpy: + + def __call__(self, pil_img): + np_img = np.array(pil_img, dtype=np.uint8) + if np_img.ndim < 3: + np_img = np.expand_dims(np_img, axis=-1) + np_img = np.rollaxis(np_img, 2) # HWC to CHW + return np_img + + +class ToTensor: + + def __init__(self, dtype=torch.float32): + self.dtype = dtype + + def __call__(self, pil_img): + np_img = np.array(pil_img, dtype=np.uint8) + if np_img.ndim < 3: + np_img = np.expand_dims(np_img, axis=-1) + np_img = np.rollaxis(np_img, 2) # HWC to CHW + return torch.from_numpy(np_img).to(dtype=self.dtype) + + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + +def _pil_interp(method): + if method == 'bicubic': + return Image.BICUBIC + elif method == 'lanczos': + return Image.LANCZOS + elif method == 'hamming': + return Image.HAMMING + else: + # default bilinear, do we want to allow nearest? + return Image.BILINEAR + + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +class RandomResizedCropAndInterpolationWithTwoPic: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, + size, + second_size=None, + scale=(0.08, 1.0), + ratio=(3. / 4., 4. / 3.), + interpolation='bilinear', + second_interpolation='lanczos'): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if second_size is not None: + if isinstance(second_size, tuple): + self.second_size = second_size + else: + self.second_size = (second_size, second_size) + else: + self.second_size = None + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = _pil_interp(interpolation) + self.second_interpolation = _pil_interp(second_interpolation) + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for attempt in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + if self.second_size is None: + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + else: + return F.resized_crop(img, i, j, h, w, self.size, + interpolation), F.resized_crop( + img, i, j, h, w, self.second_size, + self.second_interpolation) + + def __repr__(self): + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = ' '.join( + [_pil_interpolation_to_str[x] for x in self.interpolation]) + else: + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format( + tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format( + tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0}'.format(interpolate_str) + if self.second_size is not None: + format_string += ', second_size={0}'.format(self.second_size) + format_string += ', second_interpolation={0}'.format( + _pil_interpolation_to_str[self.second_interpolation]) + format_string += ')' + return format_string + + +class GroupRandomCrop(object): + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, img_tuple): + img_group, label = img_tuple + + w, h = img_group[0].size + th, tw = self.size + + out_images = list() + + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + + for img in img_group: + assert (img.size[0] == w and img.size[1] == h) + if w == tw and h == th: + out_images.append(img) + else: + out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) + + return (out_images, label) + + +class GroupCenterCrop(object): + + def __init__(self, size): + self.worker = torchvision.transforms.CenterCrop(size) + + def __call__(self, img_tuple): + img_group, label = img_tuple + return ([self.worker(img) for img in img_group], label) + + +class GroupRandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + + def __init__(self, selective_flip=True, is_flow=False): + self.is_flow = is_flow + self.class_LeftRight = [86, 87, 93, 94, 166, 167 + ] if selective_flip else [] + + def __call__(self, img_tuple, is_flow=False): + img_group, label = img_tuple + v = random.random() + if (label not in self.class_LeftRight) and v < 0.5: + ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] + if self.is_flow: + for i in range(0, len(ret), 2): + ret[i] = ImageOps.invert( + ret[i]) # invert flow pixel values when flipping + return (ret, label) + else: + return img_tuple + + +class GroupNormalize(object): + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor_tuple): + tensor, label = tensor_tuple + rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) + rep_std = self.std * (tensor.size()[0] // len(self.std)) + + # TODO: make efficient + for t, m, s in zip(tensor, rep_mean, rep_std): + t.sub_(m).div_(s) + + return (tensor, label) + + +class GroupGrayScale(object): + + def __init__(self, size): + self.worker = torchvision.transforms.Grayscale(size) + + def __call__(self, img_tuple): + img_group, label = img_tuple + return ([self.worker(img) for img in img_group], label) + + +class GroupScale(object): + """ Rescales the input PIL.Image to the given 'size'. + 'size' will be the size of the smaller edge. + For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, interpolation=Image.BILINEAR): + self.worker = torchvision.transforms.Resize(size, interpolation) + + def __call__(self, img_tuple): + img_group, label = img_tuple + return ([self.worker(img) for img in img_group], label) + + +class GroupOverSample(object): + + def __init__(self, crop_size, scale_size=None): + self.crop_size = crop_size if not isinstance(crop_size, int) else ( + crop_size, crop_size) + + if scale_size is not None: + self.scale_worker = GroupScale(scale_size) + else: + self.scale_worker = None + + def __call__(self, img_tuple): + if self.scale_worker is not None: + img_tuple = self.scale_worker(img_tuple) + + img_group, label = img_tuple + + image_w, image_h = img_group[0].size + crop_w, crop_h = self.crop_size + + offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, + crop_w, crop_h) + oversample_group = list() + for o_w, o_h in offsets: + normal_group = list() + flip_group = list() + for i, img in enumerate(img_group): + crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) + normal_group.append(crop) + flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) + + if img.mode == 'L' and i % 2 == 0: + flip_group.append(ImageOps.invert(flip_crop)) + else: + flip_group.append(flip_crop) + + oversample_group.extend(normal_group) + oversample_group.extend(flip_group) + return (oversample_group, label) + + +class GroupFullResSample(object): + + def __init__(self, crop_size, scale_size=None, flip=True): + self.crop_size = crop_size if not isinstance(crop_size, int) else ( + crop_size, crop_size) + + if scale_size is not None: + self.scale_worker = GroupScale(scale_size) + else: + self.scale_worker = None + self.flip = flip + + def __call__(self, img_tuple): + + if self.scale_worker is not None: + img_tuple = self.scale_worker(img_tuple) + + img_group, label = img_tuple + image_w, image_h = img_group[0].size + crop_w, crop_h = self.crop_size + + w_step = (image_w - crop_w) // 4 + h_step = (image_h - crop_h) // 4 + + offsets = list() + offsets.append((0 * w_step, 2 * h_step)) # left + offsets.append((4 * w_step, 2 * h_step)) # right + offsets.append((2 * w_step, 2 * h_step)) # center + + oversample_group = list() + for o_w, o_h in offsets: + normal_group = list() + flip_group = list() + for i, img in enumerate(img_group): + crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) + normal_group.append(crop) + if self.flip: + flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) + + if img.mode == 'L' and i % 2 == 0: + flip_group.append(ImageOps.invert(flip_crop)) + else: + flip_group.append(flip_crop) + + oversample_group.extend(normal_group) + oversample_group.extend(flip_group) + return (oversample_group, label) + + +class GroupMultiScaleCrop(object): + + def __init__(self, + input_size, + scales=None, + max_distort=1, + fix_crop=True, + more_fix_crop=True): + self.scales = scales if scales is not None else [1, .875, .75, .66] + self.max_distort = max_distort + self.fix_crop = fix_crop + self.more_fix_crop = more_fix_crop + self.input_size = input_size if not isinstance(input_size, int) else [ + input_size, input_size + ] + self.interpolation = Image.BILINEAR + + def __call__(self, img_tuple): + img_group, label = img_tuple + + im_size = img_group[0].size + + crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) + crop_img_group = [ + img.crop( + (offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) + for img in img_group + ] + ret_img_group = [ + img.resize((self.input_size[0], self.input_size[1]), + self.interpolation) for img in crop_img_group + ] + return (ret_img_group, label) + + def _sample_crop_size(self, im_size): + image_w, image_h = im_size[0], im_size[1] + + # find a crop size + base_size = min(image_w, image_h) + crop_sizes = [int(base_size * x) for x in self.scales] + crop_h = [ + self.input_size[1] if abs(x - self.input_size[1]) < 3 else x + for x in crop_sizes + ] + crop_w = [ + self.input_size[0] if abs(x - self.input_size[0]) < 3 else x + for x in crop_sizes + ] + + pairs = [] + for i, h in enumerate(crop_h): + for j, w in enumerate(crop_w): + if abs(i - j) <= self.max_distort: + pairs.append((w, h)) + + crop_pair = random.choice(pairs) + if not self.fix_crop: + w_offset = random.randint(0, image_w - crop_pair[0]) + h_offset = random.randint(0, image_h - crop_pair[1]) + else: + w_offset, h_offset = self._sample_fix_offset( + image_w, image_h, crop_pair[0], crop_pair[1]) + + return crop_pair[0], crop_pair[1], w_offset, h_offset + + def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): + offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, + crop_w, crop_h) + return random.choice(offsets) + + @staticmethod + def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): + w_step = (image_w - crop_w) // 4 + h_step = (image_h - crop_h) // 4 + + ret = list() + ret.append((0, 0)) # upper left + ret.append((4 * w_step, 0)) # upper right + ret.append((0, 4 * h_step)) # lower left + ret.append((4 * w_step, 4 * h_step)) # lower right + ret.append((2 * w_step, 2 * h_step)) # center + + if more_fix_crop: + ret.append((0, 2 * h_step)) # center left + ret.append((4 * w_step, 2 * h_step)) # center right + ret.append((2 * w_step, 4 * h_step)) # lower center + ret.append((2 * w_step, 0 * h_step)) # upper center + + ret.append((1 * w_step, 1 * h_step)) # upper left quarter + ret.append((3 * w_step, 1 * h_step)) # upper right quarter + ret.append((1 * w_step, 3 * h_step)) # lower left quarter + ret.append((3 * w_step, 3 * h_step)) # lower righ quarter + + return ret + + +class GroupRandomSizedCrop(object): + """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size + and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio + This is popularly used to train the Inception networks + size: size of the smaller edge + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, img_tuple): + img_group, label = img_tuple + + for attempt in range(10): + area = img_group[0].size[0] * img_group[0].size[1] + target_area = random.uniform(0.08, 1.0) * area + aspect_ratio = random.uniform(3. / 4, 4. / 3) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if random.random() < 0.5: + w, h = h, w + + if w <= img_group[0].size[0] and h <= img_group[0].size[1]: + x1 = random.randint(0, img_group[0].size[0] - w) + y1 = random.randint(0, img_group[0].size[1] - h) + found = True + break + else: + found = False + x1 = 0 + y1 = 0 + + if found: + out_group = list() + for img in img_group: + img = img.crop((x1, y1, x1 + w, y1 + h)) + assert (img.size == (w, h)) + out_group.append( + img.resize((self.size, self.size), self.interpolation)) + return out_group + else: + # Fallback + scale = GroupScale(self.size, interpolation=self.interpolation) + crop = GroupRandomCrop(self.size) + return crop(scale(img_group)) + + +class Stack(object): + + def __init__(self, roll=False): + self.roll = roll + + def __call__(self, img_tuple): + img_group, label = img_tuple + + if img_group[0].mode == 'L': + return (np.concatenate([np.expand_dims(x, 2) for x in img_group], + axis=2), label) + elif img_group[0].mode == 'RGB': + if self.roll: + return (np.concatenate( + [np.array(x)[:, :, ::-1] for x in img_group], + axis=2), label) + else: + return (np.concatenate(img_group, axis=2), label) + + +class ToTorchFormatTensor(object): + """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] + to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ + + def __init__(self, div=True): + self.div = div + + def __call__(self, pic_tuple): + pic, label = pic_tuple + + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() + else: + # handle PIL Image + img = torch.as_tensor(pic.tobytes(), dtype=torch.uint8) + img = img.view(pic.size[1], pic.size[0], len(pic.mode)) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + return (img.float().div(255.) if self.div else img.float(), label) + + +class IdentityTransform(object): + + def __call__(self, data): + return data diff --git a/v_cls/video_transforms.py b/v_cls/video_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc045e39f214315f5d754c7c7aaeb7524f06b4d --- /dev/null +++ b/v_cls/video_transforms.py @@ -0,0 +1,1267 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +import math +import numbers +import random + +import numpy as np +import PIL +import torch +import torchvision +import torchvision.transforms.functional as F +from PIL import Image +from torchvision import transforms + +from . import functional as FF +from .rand_augment import rand_augment_transform +from .random_erasing import RandomErasing + +_pil_interpolation_to_str = { + Image.NEAREST: "PIL.Image.NEAREST", + Image.BILINEAR: "PIL.Image.BILINEAR", + Image.BICUBIC: "PIL.Image.BICUBIC", + Image.LANCZOS: "PIL.Image.LANCZOS", + Image.HAMMING: "PIL.Image.HAMMING", + Image.BOX: "PIL.Image.BOX", +} + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _pil_interp(method): + if method == "bicubic": + return Image.BICUBIC + elif method == "lanczos": + return Image.LANCZOS + elif method == "hamming": + return Image.HAMMING + else: + return Image.BILINEAR + + +def random_short_side_scale_jitter(images, + min_size, + max_size, + boxes=None, + inverse_uniform_sampling=False): + """ + Perform a spatial short scale jittering on the given images and + corresponding boxes. + Args: + images (tensor): images to perform scale jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + min_size (int): the minimal size to scale the frames. + max_size (int): the maximal size to scale the frames. + boxes (ndarray): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + inverse_uniform_sampling (bool): if True, sample uniformly in + [1 / max_scale, 1 / min_scale] and take a reciprocal to get the + scale. If False, take a uniform sample from [min_scale, max_scale]. + Returns: + (tensor): the scaled images with dimension of + `num frames` x `channel` x `new height` x `new width`. + (ndarray or None): the scaled boxes with dimension of + `num boxes` x 4. + """ + if inverse_uniform_sampling: + size = int( + round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))) + else: + size = int(round(np.random.uniform(min_size, max_size))) + + height = images.shape[2] + width = images.shape[3] + if (width <= height and width == size) or (height <= width + and height == size): + return images, boxes + new_width = size + new_height = size + if width < height: + new_height = int(math.floor((float(height) / width) * size)) + if boxes is not None: + boxes = boxes * float(new_height) / height + else: + new_width = int(math.floor((float(width) / height) * size)) + if boxes is not None: + boxes = boxes * float(new_width) / width + + return ( + torch.nn.functional.interpolate( + images, + size=(new_height, new_width), + mode="bilinear", + align_corners=False, + ), + boxes, + ) + + +def crop_boxes(boxes, x_offset, y_offset): + """ + Peform crop on the bounding boxes given the offsets. + Args: + boxes (ndarray or None): bounding boxes to peform crop. The dimension + is `num boxes` x 4. + x_offset (int): cropping offset in the x axis. + y_offset (int): cropping offset in the y axis. + Returns: + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + cropped_boxes = boxes.copy() + cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset + cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset + + return cropped_boxes + + +def random_crop(images, size, boxes=None): + """ + Perform random spatial crop on the given images and corresponding boxes. + Args: + images (tensor): images to perform random crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): the size of height and width to crop on the image. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + cropped (tensor): cropped images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + if images.shape[2] == size and images.shape[3] == size: + return images + height = images.shape[2] + width = images.shape[3] + y_offset = 0 + if height > size: + y_offset = int(np.random.randint(0, height - size)) + x_offset = 0 + if width > size: + x_offset = int(np.random.randint(0, width - size)) + cropped = images[:, :, y_offset:y_offset + size, x_offset:x_offset + size] + + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None) + + return cropped, cropped_boxes + + +def horizontal_flip(prob, images, boxes=None): + """ + Perform horizontal flip on the given images and corresponding boxes. + Args: + prob (float): probility to flip the images. + images (tensor): images to perform horizontal flip, the dimension is + `num frames` x `channel` x `height` x `width`. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + images (tensor): images with dimension of + `num frames` x `channel` x `height` x `width`. + flipped_boxes (ndarray or None): the flipped boxes with dimension of + `num boxes` x 4. + """ + if boxes is None: + flipped_boxes = None + else: + flipped_boxes = boxes.copy() + + if np.random.uniform() < prob: + images = images.flip((-1)) + + if len(images.shape) == 3: + width = images.shape[2] + elif len(images.shape) == 4: + width = images.shape[3] + else: + raise NotImplementedError("Dimension does not supported") + if boxes is not None: + flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1 + + return images, flipped_boxes + + +def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): + """ + Perform uniform spatial sampling on the images and corresponding boxes. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[:, :, y_offset:y_offset + size, x_offset:x_offset + size] + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None) + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped, cropped_boxes + + +def clip_boxes_to_image(boxes, height, width): + """ + Clip an array of boxes to an image with the given height and width. + Args: + boxes (ndarray): bounding boxes to perform clipping. + Dimension is `num boxes` x 4. + height (int): given image height. + width (int): given image width. + Returns: + clipped_boxes (ndarray): the clipped boxes with dimension of + `num boxes` x 4. + """ + clipped_boxes = boxes.copy() + clipped_boxes[:, [0, 2]] = np.minimum(width - 1.0, + np.maximum(0.0, boxes[:, [0, 2]])) + clipped_boxes[:, [1, 3]] = np.minimum(height - 1.0, + np.maximum(0.0, boxes[:, [1, 3]])) + return clipped_boxes + + +def blend(images1, images2, alpha): + """ + Blend two images with a given weight alpha. + Args: + images1 (tensor): the first images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + images2 (tensor): the second images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + alpha (float): the blending weight. + Returns: + (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + return images1 * alpha + images2 * (1 - alpha) + + +def grayscale(images): + """ + Get the grayscale for the input images. The channels of images should be + in order BGR. + Args: + images (tensor): the input images for getting grayscale. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + img_gray (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + # R -> 0.299, G -> 0.587, B -> 0.114. + img_gray = torch.tensor(images) + gray_channel = (0.299 * images[:, 2] + 0.587 * images[:, 1] + + 0.114 * images[:, 0]) + img_gray[:, 0] = gray_channel + img_gray[:, 1] = gray_channel + img_gray[:, 2] = gray_channel + return img_gray + + +def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): + """ + Perfrom a color jittering on the input images. The channels of images + should be in order BGR. + Args: + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + img_brightness (float): jitter ratio for brightness. + img_contrast (float): jitter ratio for contrast. + img_saturation (float): jitter ratio for saturation. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + + jitter = [] + if img_brightness != 0: + jitter.append("brightness") + if img_contrast != 0: + jitter.append("contrast") + if img_saturation != 0: + jitter.append("saturation") + + if len(jitter) > 0: + order = np.random.permutation(np.arange(len(jitter))) + for idx in range(0, len(jitter)): + if jitter[order[idx]] == "brightness": + images = brightness_jitter(img_brightness, images) + elif jitter[order[idx]] == "contrast": + images = contrast_jitter(img_contrast, images) + elif jitter[order[idx]] == "saturation": + images = saturation_jitter(img_saturation, images) + return images + + +def brightness_jitter(var, images): + """ + Perfrom brightness jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for brightness. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_bright = torch.zeros(images.shape) + images = blend(images, img_bright, alpha) + return images + + +def contrast_jitter(var, images): + """ + Perfrom contrast jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for contrast. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_gray = grayscale(images) + img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True) + images = blend(images, img_gray, alpha) + return images + + +def saturation_jitter(var, images): + """ + Perfrom saturation jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for saturation. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + img_gray = grayscale(images) + images = blend(images, img_gray, alpha) + + return images + + +def lighting_jitter(images, alphastd, eigval, eigvec): + """ + Perform AlexNet-style PCA jitter on the given images. + Args: + images (tensor): images to perform lighting jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + alphastd (float): jitter ratio for PCA jitter. + eigval (list): eigenvalues for PCA jitter. + eigvec (list[list]): eigenvectors for PCA jitter. + Returns: + out_images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if alphastd == 0: + return images + # generate alpha1, alpha2, alpha3. + alpha = np.random.normal(0, alphastd, size=(1, 3)) + eig_vec = np.array(eigvec) + eig_val = np.reshape(eigval, (1, 3)) + rgb = np.sum( + eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), + axis=1, + ) + out_images = torch.zeros_like(images) + if len(images.shape) == 3: + # C H W + channel_dim = 0 + elif len(images.shape) == 4: + # T C H W + channel_dim = 1 + else: + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") + + for idx in range(images.shape[channel_dim]): + # C H W + if len(images.shape) == 3: + out_images[idx] = images[idx] + rgb[2 - idx] + # T C H W + elif len(images.shape) == 4: + out_images[:, idx] = images[:, idx] + rgb[2 - idx] + else: + raise NotImplementedError( + f"Unsupported dimension {len(images.shape)}") + + return out_images + + +def color_normalization(images, mean, stddev): + """ + Perform color nomration on the given images. + Args: + images (tensor): images to perform color normalization. Dimension is + `num frames` x `channel` x `height` x `width`. + mean (list): mean values for normalization. + stddev (list): standard deviations for normalization. + + Returns: + out_images (tensor): the noramlized images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if len(images.shape) == 3: + assert ( + len(mean) == images.shape[0]), "channel mean not computed properly" + assert (len(stddev) == images.shape[0] + ), "channel stddev not computed properly" + elif len(images.shape) == 4: + assert ( + len(mean) == images.shape[1]), "channel mean not computed properly" + assert (len(stddev) == images.shape[1] + ), "channel stddev not computed properly" + else: + raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") + + out_images = torch.zeros_like(images) + for idx in range(len(mean)): + # C H W + if len(images.shape) == 3: + out_images[idx] = (images[idx] - mean[idx]) / stddev[idx] + elif len(images.shape) == 4: + out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] + else: + raise NotImplementedError( + f"Unsupported dimension {len(images.shape)}") + return out_images + + +def _get_param_spatial_crop(scale, + ratio, + height, + width, + num_repeat=10, + log_scale=True, + switch_hw=False): + """ + Given scale, ratio, height and width, return sampled coordinates of the videos. + """ + for _ in range(num_repeat): + area = height * width + target_area = random.uniform(*scale) * area + if log_scale: + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + else: + aspect_ratio = random.uniform(*ratio) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if np.random.uniform() < 0.5 and switch_hw: + w, h = h, w + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(ratio): + w = width + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + +def random_resized_crop( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + Crop the given images to random size and aspect ratio. A crop of random + size (default: of 0.08 to 1.0) of the original size and a random aspect + ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This + crop is finally resized to given size. This is popularly used to train the + Inception networks. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + cropped = images[:, :, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped, + size=(target_height, target_width), + mode="bilinear", + align_corners=False, + ) + + +def random_resized_crop_with_shift( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + This is similar to random_resized_crop. However, it samples two different + boxes (for cropping) for the first and last frame. It then linearly + interpolates the two boxes for other frames. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + t = images.shape[1] + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width) + i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()] + j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()] + h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()] + w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] + out = torch.zeros((3, t, target_height, target_width)) + for ind in range(t): + out[:, ind:ind + 1, :, :] = torch.nn.functional.interpolate( + images[:, ind:ind + 1, i_s[ind]:i_s[ind] + h_s[ind], + j_s[ind]:j_s[ind] + w_s[ind], ], + size=(target_height, target_width), + mode="bilinear", + align_corners=False, + ) + return out + + +def create_random_augment( + input_size, + auto_augment=None, + interpolation="bilinear", +): + """ + Get video randaug transform. + + Args: + input_size: The size of the input video in tuple. + auto_augment: Parameters for randaug. An example: + "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number + of operations to apply). + interpolation: Interpolation method. + """ + if isinstance(input_size, tuple): + img_size = input_size[-2:] + else: + img_size = input_size + + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = {"translate_const": int(img_size_min * 0.45)} + if interpolation and interpolation != "random": + aa_params["interpolation"] = _pil_interp(interpolation) + if auto_augment.startswith("rand"): + return transforms.Compose( + [rand_augment_transform(auto_augment, aa_params)]) + raise NotImplementedError + + +def random_sized_crop_img( + im, + size, + jitter_scale=(0.08, 1.0), + jitter_aspect=(3.0 / 4.0, 4.0 / 3.0), + max_iter=10, +): + """ + Performs Inception-style cropping (used for training). + """ + assert (len( + im.shape) == 3), "Currently only support image for random_sized_crop" + h, w = im.shape[1:3] + i, j, h, w = _get_param_spatial_crop( + scale=jitter_scale, + ratio=jitter_aspect, + height=h, + width=w, + num_repeat=max_iter, + log_scale=False, + switch_hw=True, + ) + cropped = im[:, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped.unsqueeze(0), + size=(size, size), + mode="bilinear", + align_corners=False, + ).squeeze(0) + + +# The following code are modified based on timm lib, we will replace the following +# contents with dependency from PyTorchVideo. +# https://github.com/facebookresearch/pytorchvideo +class RandomResizedCropAndInterpolation: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation="bilinear", + ): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + print("range should be of kind (min, max)") + + if interpolation == "random": + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = _pil_interp(interpolation) + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for _ in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + Returns: + PIL Image: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + + def __repr__(self): + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = " ".join( + [_pil_interpolation_to_str[x] for x in self.interpolation]) + else: + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + "(size={0}".format(self.size) + format_string += ", scale={0}".format( + tuple(round(s, 4) for s in self.scale)) + format_string += ", ratio={0}".format( + tuple(round(r, 4) for r in self.ratio)) + format_string += ", interpolation={0})".format(interpolate_str) + return format_string + + +def transforms_imagenet_train( + img_size=224, + scale=None, + ratio=None, + hflip=0.5, + vflip=0.0, + color_jitter=0.4, + auto_augment=None, + interpolation="random", + use_prefetcher=False, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + re_prob=0.0, + re_mode="const", + re_count=1, + re_num_splits=0, + separate=False, +): + """ + If separate==True, the transforms are returned as a tuple of 3 separate transforms + for use in a mixing dataset that passes + * all data through the first (primary) transform, called the 'clean' data + * a portion of the data through the secondary transform + * normalizes and converts the branches above with the third, final transform + """ + if isinstance(img_size, tuple): + img_size = img_size[-2:] + else: + img_size = img_size + + scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range + ratio = tuple(ratio + or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range + primary_tfl = [ + RandomResizedCropAndInterpolation( + img_size, scale=scale, ratio=ratio, interpolation=interpolation) + ] + if hflip > 0.0: + primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] + if vflip > 0.0: + primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] + + secondary_tfl = [] + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = dict( + translate_const=int(img_size_min * 0.45), + img_mean=tuple([min(255, round(255 * x)) for x in mean]), + ) + if interpolation and interpolation != "random": + aa_params["interpolation"] = _pil_interp(interpolation) + if auto_augment.startswith("rand"): + secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] + elif auto_augment.startswith("augmix"): + raise NotImplementedError("Augmix not implemented") + else: + raise NotImplementedError("Auto aug not implemented") + elif color_jitter is not None: + # color jitter is enabled when not using AA + if isinstance(color_jitter, (list, tuple)): + # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation + # or 4 if also augmenting hue + assert len(color_jitter) in (3, 4) + else: + # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue + color_jitter = (float(color_jitter), ) * 3 + secondary_tfl += [transforms.ColorJitter(*color_jitter)] + + final_tfl = [] + final_tfl += [ + transforms.ToTensor(), + transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), + ] + if re_prob > 0.0: + final_tfl.append( + RandomErasing( + re_prob, + mode=re_mode, + max_count=re_count, + num_splits=re_num_splits, + device="cpu", + cube=False, + )) + + if separate: + return ( + transforms.Compose(primary_tfl), + transforms.Compose(secondary_tfl), + transforms.Compose(final_tfl), + ) + else: + return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) + + +############################################################################################################ +############################################################################################################ + + +class Compose(object): + """Composes several transforms + Args: + transforms (list of ``Transform`` objects): list of transforms + to compose + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, clip): + for t in self.transforms: + clip = t(clip) + return clip + + +class RandomHorizontalFlip(object): + """Horizontally flip the list of given images randomly + with a probability 0.5 + """ + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Randomly flipped clip + """ + if random.random() < 0.5: + if isinstance(clip[0], np.ndarray): + return [np.fliplr(img) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + return [ + img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + ' but got list of {0}'.format(type(clip[0]))) + return clip + + +class RandomResize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'): + self.ratio = ratio + self.interpolation = interpolation + + def __call__(self, clip): + scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) + + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + + new_w = int(im_w * scaling_factor) + new_h = int(im_h * scaling_factor) + new_size = (new_w, new_h) + resized = FF.resize_clip( + clip, new_size, interpolation=self.interpolation) + return resized + + +class Resize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, size, interpolation='nearest'): + self.size = size + self.interpolation = interpolation + + def __call__(self, clip): + resized = FF.resize_clip( + clip, self.size, interpolation=self.interpolation) + return resized + + +class RandomCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = random.randint(0, im_w - w) + y1 = random.randint(0, im_h - h) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ThreeCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w != im_w and h != im_h: + clip = FF.resize_clip(clip, self.size, interpolation="bilinear") + im_h, im_w, im_c = clip[0].shape + + step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0) + cropped = [] + for i in range(3): + if (im_h > self.size[0]): + x1 = 0 + y1 = i * step + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + else: + x1 = i * step + y1 = 0 + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + return cropped + + +class RandomRotation(object): + """Rotate entire clip randomly by a random angle within + given bounds + Args: + degrees (sequence or int): Range of degrees to select from + If degrees is a number instead of sequence like (min, max), + the range of degrees, will be (-degrees, +degrees). + """ + + def __init__(self, degrees): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError('If degrees is a single number,' + 'must be positive') + degrees = (-degrees, degrees) + else: + if len(degrees) != 2: + raise ValueError('If degrees is a sequence,' + 'it must be of len 2.') + + self.degrees = degrees + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + import skimage + angle = random.uniform(self.degrees[0], self.degrees[1]) + if isinstance(clip[0], np.ndarray): + rotated = [skimage.transform.rotate(img, angle) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + rotated = [img.rotate(angle) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + + return rotated + + +class CenterCrop(object): + """Extract center crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = int(round((im_w - w) / 2.)) + y1 = int(round((im_h - h) / 2.)) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ColorJitter(object): + """Randomly change the brightness, contrast and saturation and hue of the clip + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. + """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + def get_params(self, brightness, contrast, saturation, hue): + if brightness > 0: + brightness_factor = random.uniform( + max(0, 1 - brightness), 1 + brightness) + else: + brightness_factor = None + + if contrast > 0: + contrast_factor = random.uniform( + max(0, 1 - contrast), 1 + contrast) + else: + contrast_factor = None + + if saturation > 0: + saturation_factor = random.uniform( + max(0, 1 - saturation), 1 + saturation) + else: + saturation_factor = None + + if hue > 0: + hue_factor = random.uniform(-hue, hue) + else: + hue_factor = None + return brightness_factor, contrast_factor, saturation_factor, hue_factor + + def __call__(self, clip): + """ + Args: + clip (list): list of PIL.Image + Returns: + list PIL.Image : list of transformed PIL.Image + """ + if isinstance(clip[0], np.ndarray): + raise TypeError( + 'Color jitter not yet implemented for numpy arrays') + elif isinstance(clip[0], PIL.Image.Image): + brightness, contrast, saturation, hue = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + # Create img transform function sequence + img_transforms = [] + if brightness is not None: + img_transforms.append( + lambda img: torchvision.transforms.functional. + adjust_brightness(img, brightness)) + if saturation is not None: + img_transforms.append( + lambda img: torchvision.transforms.functional. + adjust_saturation(img, saturation)) + if hue is not None: + img_transforms.append(lambda img: torchvision.transforms. + functional.adjust_hue(img, hue)) + if contrast is not None: + img_transforms.append( + lambda img: torchvision.transforms.functional. + adjust_contrast(img, contrast)) + random.shuffle(img_transforms) + + # Apply to all images + jittered_clip = [] + for img in clip: + for func in img_transforms: + jittered_img = func(img) + jittered_clip.append(jittered_img) + + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return jittered_clip + + +class Normalize(object): + """Normalize a clip with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + .. note:: + This transform acts out of place, i.e., it does not mutates the input tensor. + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, clip): + """ + Args: + clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor clip. + """ + return FF.normalize(clip, self.mean, self.std) + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format( + self.mean, self.std) diff --git a/v_cls/volume_transforms.py b/v_cls/volume_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..80e002acd6d663e5b60096c7b2ba89b04e2eff9e --- /dev/null +++ b/v_cls/volume_transforms.py @@ -0,0 +1,131 @@ +import numpy as np +import torch +from PIL import Image + + +def convert_img(img): + """Converts (H, W, C) numpy.ndarray to (C, W, H) format + """ + if len(img.shape) == 3: + img = img.transpose(2, 0, 1) + if len(img.shape) == 2: + img = np.expand_dims(img, 0) + return img + + +class ClipToTensor(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor. + """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( + ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image\ + but got list of {0}'.format(type(clip[0]))) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError('Expected numpy.ndarray or PIL.Image\ + but got list of {0}'.format(type(clip[0]))) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = np_clip / 255.0 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(tensor_clip, 255) + return tensor_clip + + +# Note this norms data to -1/1 +class ClipToTensor_K(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor. + """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( + ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image\ + but got list of {0}'.format(type(clip[0]))) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError('Expected numpy.ndarray or PIL.Image\ + but got list of {0}'.format(type(clip[0]))) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = (np_clip - 127.5) / 127.5 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) + return tensor_clip + + +class ToTensor(object): + """Converts numpy array to tensor + """ + + def __call__(self, array): + tensor = torch.from_numpy(array) + return tensor diff --git a/v_cls/zero_shot.py b/v_cls/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..a956b7c19dae9bac3af841a827a715c356c0bc93 --- /dev/null +++ b/v_cls/zero_shot.py @@ -0,0 +1,109 @@ +import logging +import os + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from open_clip import get_input_dtype, get_tokenizer +from open_clip.factory import HF_HUB_PREFIX +from training.distributed import is_master +from v_cls.zero_shot_classifier import build_zero_shot_classifier +from v_cls.zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, CLASSNAMES + +from training.precision import get_autocast + + + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] + + +def run(model, classifier, dataloader, args): + autocast = get_autocast(args.precision) + input_dtype = get_input_dtype(args.precision) + file = os.path.join(args.output_dir, str(args.rank) + '.txt') + final_result = [] + with torch.no_grad(): + top1, top5, n = 0., 0., 0. + for batch in tqdm(dataloader, unit_scale=args.batch_size): + images = batch[0] + target = batch[1] + ids = batch[2] + chunk_nb = batch[3] + split_nb = batch[4] + images = images.to(device=args.device, dtype=input_dtype) + target = target.to(args.device) + + with autocast(): + # predict + output = model(image=images) + image_features = output['image_features'] if isinstance(output, dict) else output[0] + logits = 100. * image_features @ classifier + output = logits + # print(output.shape) + for i in range(output.size(0)): + string = "{} {} {} {} {}\n".format( + ids[i], str(output.data[i].cpu().numpy().tolist()), + str(int(target[i].cpu().numpy())), + str(int(chunk_nb[i].cpu().numpy())), + str(int(split_nb[i].cpu().numpy()))) + final_result.append(string) + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = (top1 / n) + top5 = (top5 / n) + + if not os.path.exists(file): + os.mknod(file) + with open(file, 'w') as f: + f.write("{}, {}\n".format(top1, top5)) + for line in final_result: + f.write(line) + + return top1, top5 + + +def zero_shot_eval(model, dataloader, epoch, args): + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + if args.distributed and not args.horovod: + model = model.module + if is_master(args): + logging.info(f'Starting zero-shot {args.val_v_cls_data[0].upper()}') + logging.info('Building zero-shot classifier') + autocast = get_autocast(args.precision) + with autocast(): + tokenizer = get_tokenizer(HF_HUB_PREFIX+args.model, cache_dir=args.cache_dir) + classifier = build_zero_shot_classifier( + model, + tokenizer=tokenizer, + classnames=CLASSNAMES[args.val_v_cls_data[0]], + templates=OPENAI_IMAGENET_TEMPLATES, + num_classes_per_batch=10, + device=args.device, + use_tqdm=True, + ) + + + if is_master(args): + logging.info('Using classifier') + # results = {} + run(model, classifier, dataloader, args) + # results['kinetics400-zeroshot-val-top1'] = top1 + # results['kinetics400-zeroshot-val-top5'] = top5 + + if is_master(args): + logging.info(f'Finished zero-shot {args.val_v_cls_data[0].upper()}.') + + # return results diff --git a/v_cls/zero_shot_classifier.py b/v_cls/zero_shot_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a5267cea4119994e30bb4830a6744cf25bdbaf --- /dev/null +++ b/v_cls/zero_shot_classifier.py @@ -0,0 +1,111 @@ +from functools import partial +from itertools import islice +from typing import Callable, List, Optional, Sequence, Union + +import torch +import torch.nn.functional as F + + +def batched(iterable, n): + """Batch data into lists of length *n*. The last batch may be shorter. + NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl + """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch + + +def build_zero_shot_classifier( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + num_classes_per_batch: Optional[int] = 10, + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names in batches + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + num_classes_per_batch: The number of classes to batch together in each forward, all if None + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + use_format = isinstance(templates[0], str) + num_templates = len(templates) + num_classes = len(classnames) + if use_tqdm: + import tqdm + num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) + iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) + else: + iter_wrap = iter + + def _process_batch(batch_classnames): + num_batch_classes = len(batch_classnames) + texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] + input_ids, attention_mask = tokenizer(texts) + input_ids, attention_mask = input_ids.to(device), attention_mask.to(device) + class_embeddings = F.normalize(model.encode_text(input_ids, attention_mask), dim=-1) + class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) + class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) + class_embeddings = class_embeddings.T + return class_embeddings + + with torch.no_grad(): + if num_classes_per_batch: + batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] + zeroshot_weights = torch.cat(batched_embeds, dim=1) + else: + zeroshot_weights = _process_batch(classnames) + return zeroshot_weights + + +def build_zero_shot_classifier_legacy( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names 1 by 1 + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + device: Device to use. + use_tqdm: Enable TQDM progress bar. + """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + if use_tqdm: + import tqdm + iter_wrap = tqdm.tqdm + else: + iter_wrap = iter + + use_format = isinstance(templates[0], str) + + with torch.no_grad(): + zeroshot_weights = [] + for classname in iter_wrap(classnames): + texts = [template.format(classname) if use_format else template(classname) for template in templates] + texts = tokenizer(texts).to(device) # tokenize + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) + + return zeroshot_weights + diff --git a/v_cls/zero_shot_metadata.py b/v_cls/zero_shot_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..db255c6d04cf25a2dbc56cc987fc072c417e4959 --- /dev/null +++ b/v_cls/zero_shot_metadata.py @@ -0,0 +1,108 @@ +import os + +import pandas as pd + +OPENAI_IMAGENET_TEMPLATES = ( + lambda c: f'a bad video of a {c}.', + lambda c: f'a video of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a video of the hard to see {c}.', + lambda c: f'a low resolution video of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad video of the {c}.', + lambda c: f'a cropped video of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a video of a hard to see {c}.', + lambda c: f'a bright video of a {c}.', + lambda c: f'a video of a clean {c}.', + lambda c: f'a video of a dirty {c}.', + lambda c: f'a dark video of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a video of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a video of the cool {c}.', + lambda c: f'a close-up video of a {c}.', + lambda c: f'a black and white video of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated video of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright video of the {c}.', + lambda c: f'a cropped video of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a video of the dirty {c}.', + lambda c: f'a jpeg corrupted video of a {c}.', + lambda c: f'a blurry video of the {c}.', + lambda c: f'a video of the {c}.', + lambda c: f'a good video of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a video of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up video of the {c}.', + lambda c: f'a video of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution video of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a video of the clean {c}.', + lambda c: f'a video of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a video of a nice {c}.', + lambda c: f'a video of a weird {c}.', + lambda c: f'a blurry video of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated video of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted video of the {c}.', + lambda c: f'a good video of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a video of the nice {c}.', + lambda c: f'a video of the small {c}.', + lambda c: f'a video of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a video of the large {c}.', + lambda c: f'a black and white video of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark video of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a video of a cool {c}.', + lambda c: f'a video of a small {c}.', + lambda c: f'a tattoo of the {c}.', +) + + +# a much smaller subset of above prompts +# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +SIMPLE_IMAGENET_TEMPLATES = ( + lambda c: f'itap of a {c}.', + lambda c: f'a bad video of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a video of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a video of the small {c}.', +) + +PATH_k400 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "kinetics_400_labels.csv") +PATH_k600 = os.path.join(os.path.dirname(os.path.abspath(__file__)), "kinetics_600_labels.csv") +CLASSNAMES = { + 'Kinetics-400': tuple(pd.read_csv(PATH_k400).values[:, 1]), + 'Kinetics-600': tuple(pd.read_csv(PATH_k600).values[:, 1]), + +} + diff --git a/v_cls/zeroshot_cls.py b/v_cls/zeroshot_cls.py new file mode 100644 index 0000000000000000000000000000000000000000..e0caea0e209a6aba575ab46c1da15dc56da0e08a --- /dev/null +++ b/v_cls/zeroshot_cls.py @@ -0,0 +1,136 @@ +import json +import logging +import os + +import numpy as np +import torch +from scipy.special import softmax +from training.distributed import is_master +from .zero_shot import zero_shot_eval + + +def compute_video(lst): + i, video_id, data, label = lst + feat = [x for x in data] + feat = np.mean(feat, axis=0) + pred = np.argmax(feat) + top1 = (int(pred) == int(label)) * 1.0 + top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0 + return [pred, top1, top5, int(label)] + +def merge(eval_path, num_tasks, method='prob'): + assert method in ['prob', 'score'] + dict_feats = {} + dict_label = {} + dict_pos = {} + # logging.info("Reading individual output files") + + for x in range(num_tasks): + file = os.path.join(eval_path, str(x) + '.txt') + lines = open(file, 'r').readlines()[1:] + for line in lines: + line = line.strip() + name = line.split('[')[0] + label = line.split(']')[1].split(' ')[1] + chunk_nb = line.split(']')[1].split(' ')[2] + split_nb = line.split(']')[1].split(' ')[3] + data = np.fromstring( + line.split('[')[1].split(']')[0], dtype=np.float, sep=',') + if name not in dict_feats: + dict_feats[name] = [] + dict_label[name] = 0 + dict_pos[name] = [] + if chunk_nb + split_nb in dict_pos[name]: + continue + if method == 'prob': + dict_feats[name].append(softmax(data)) + else: + dict_feats[name].append(data) + dict_pos[name].append(chunk_nb + split_nb) + dict_label[name] = label + # logging.info("Computing final results") + + input_lst = [] + # logging.info(f"{len(dict_feats)}") + for i, item in enumerate(dict_feats): + input_lst.append([i, item, dict_feats[item], dict_label[item]]) + from multiprocessing import Pool + p = Pool(64) + ans = p.map(compute_video, input_lst) + top1 = [x[1] for x in ans] + top5 = [x[2] for x in ans] + # pred = [x[0] for x in ans] + label = [x[3] for x in ans] + final_top1, final_top5 = np.mean(top1), np.mean(top5) + + return final_top1 * 100, final_top5 * 100 + + +# def evaluate_v_cls(model, data, epoch, args, tb_writer=None): +# model.eval() +# dataloader = data['v_cls'] +# args.output_dir = os.path.join(args.log_base_path, 'video_cls') +# os.makedirs(args.output_dir, exist_ok=True) +# if args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs): +# if is_master(args): +# logging.info(f"Eval Epoch: {epoch}, accuracy of zero-shot classification under Kinetics-400 test videos") +# zero_shot_eval(model, dataloader, epoch, args) +# +# torch.distributed.barrier() +# +# if is_master(args): +# # logging.info("Start merging results...") +# final_top1, final_top5 = merge(args.output_dir, args.world_size) +# logging.info(f"\t>>> Acc@1: {final_top1:.2f}%, Acc@5: {final_top5:.2f}%") +# metrics = {'top-1': final_top1, 'top-5': final_top5} +# +# if args.save_logs: +# for name, val in metrics.items(): +# if tb_writer is not None: +# tb_writer.add_scalar(f"val/v_cls/{name}", val, epoch) +# +# with open(os.path.join(args.output_dir, "results.jsonl"), "a+") as f: +# f.write(json.dumps(metrics)) +# f.write("\n") +# +# return metrics + +def evaluate_v_cls(model, data, epoch, args, tb_writer=None): + temp_val_v_cls_data = args.val_v_cls_data + args.val_v_cls_data = list(data.keys()) + assert len(args.val_v_cls_data) == 1 + + + model.eval() + dataloader = data[args.val_v_cls_data[0]] + + + + args.output_dir = os.path.join(args.log_base_path, f'video_cls/{args.val_v_cls_data[0].lower()}') + os.makedirs(args.output_dir, exist_ok=True) + if args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs): + if is_master(args): + logging.info(f"Eval Epoch: {epoch}, accuracy of zero-shot classification under {args.val_v_cls_data[0].lower()} test videos") + zero_shot_eval(model, dataloader, epoch, args) + + torch.distributed.barrier() + + if is_master(args): + logging.info("Start merging results...") + final_top1, final_top5 = merge(args.output_dir, args.world_size) + logging.info(f"\t>>> Acc@1: {final_top1:.2f}%, Acc@5: {final_top5:.2f}%") + metrics = {'top-1': final_top1, 'top-5': final_top5} + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/v_cls/{args.val_v_cls_data[0].lower()}/{name}", val, epoch) + + with open(os.path.join(args.output_dir, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + args.val_v_cls_data = temp_val_v_cls_data + return metrics + + args.val_v_cls_data = temp_val_v_cls_data diff --git a/vl_ret/bpe_simple_vocab_16e6.txt.gz b/vl_ret/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/vl_ret/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/vl_ret/data_dataloaders.py b/vl_ret/data_dataloaders.py new file mode 100644 index 0000000000000000000000000000000000000000..9981c1877bcb2e989a39e5c459c37852dac78992 --- /dev/null +++ b/vl_ret/data_dataloaders.py @@ -0,0 +1,297 @@ +import argparse +import torch +from torch.utils.data import DataLoader + +from data.build_datasets import get_data +from .dataloader_msrvtt_retrieval import MSRVTT_DataLoader +from .dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader +from .dataloader_msvd_retrieval import MSVD_DataLoader +from .dataloader_lsmdc_retrieval import LSMDC_DataLoader +from .dataloader_activitynet_retrieval import ActivityNet_DataLoader +from .dataloader_didemo_retrieval import DiDeMo_DataLoader + +def dataloader_msrvtt_train(args, tokenizer): + msrvtt_dataset = MSRVTT_TrainDataLoader( + csv_path=args.train_csv, + json_path=args.data_path, + features_path=args.features_path, + max_words=args.max_words, + feature_framerate=args.feature_framerate, + tokenizer=tokenizer, + max_frames=args.max_frames, + unfold_sentences=args.expand_msrvtt_sentences, + frame_order=args.train_frame_order, + slice_framepos=args.slice_framepos, + ) + + train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset) + dataloader = DataLoader( + msrvtt_dataset, + batch_size=args.batch_size // args.n_gpu, + num_workers=args.num_thread_reader, + pin_memory=False, + shuffle=(train_sampler is None), + sampler=train_sampler, + drop_last=True, + ) + + return dataloader, len(msrvtt_dataset), train_sampler + +def dataloader_msrvtt_test(args, tokenizer, subset="test"): + msrvtt_testset = MSRVTT_DataLoader( + csv_path=args.val_csv, + features_path=args.features_path, + max_words=args.max_words, + feature_framerate=args.feature_framerate, + tokenizer=tokenizer, + max_frames=args.max_frames, + frame_order=args.eval_frame_order, + slice_framepos=args.slice_framepos, + ) + dataloader_msrvtt = DataLoader( + msrvtt_testset, + batch_size=args.batch_size_val, + num_workers=args.num_thread_reader, + shuffle=False, + drop_last=False, + ) + return dataloader_msrvtt, len(msrvtt_testset) + + +def dataloader_msvd_train(args, tokenizer): + msvd_dataset = MSVD_DataLoader( + subset="train", + data_path=args.data_path, + features_path=args.features_path, + max_words=args.max_words, + feature_framerate=args.feature_framerate, + tokenizer=tokenizer, + max_frames=args.max_frames, + frame_order=args.train_frame_order, + slice_framepos=args.slice_framepos, + ) + + train_sampler = torch.utils.data.distributed.DistributedSampler(msvd_dataset) + dataloader = DataLoader( + msvd_dataset, + batch_size=args.batch_size // args.n_gpu, + num_workers=args.num_thread_reader, + pin_memory=False, + shuffle=(train_sampler is None), + sampler=train_sampler, + drop_last=True, + ) + + return dataloader, len(msvd_dataset), train_sampler + +def dataloader_msvd_test(args, tokenizer, subset="test"): + msvd_testset = MSVD_DataLoader( + subset=subset, + data_path=args.data_path, + features_path=args.features_path, + max_words=args.max_words, + feature_framerate=args.feature_framerate, + tokenizer=tokenizer, + max_frames=args.max_frames, + frame_order=args.eval_frame_order, + slice_framepos=args.slice_framepos, + ) + dataloader_msrvtt = DataLoader( + msvd_testset, + batch_size=args.batch_size_val, + num_workers=args.num_thread_reader, + shuffle=False, + drop_last=False, + ) + return dataloader_msrvtt, len(msvd_testset) + + +def dataloader_lsmdc_train(args, tokenizer): + lsmdc_dataset = LSMDC_DataLoader( + subset="train", + data_path=args.data_path, + features_path=args.features_path, + max_words=args.max_words, + feature_framerate=args.feature_framerate, + tokenizer=tokenizer, + max_frames=args.max_frames, + frame_order=args.train_frame_order, + slice_framepos=args.slice_framepos, + ) + + train_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_dataset) + dataloader = DataLoader( + lsmdc_dataset, + batch_size=args.batch_size // args.n_gpu, + num_workers=args.num_thread_reader, + pin_memory=False, + shuffle=(train_sampler is None), + sampler=train_sampler, + drop_last=True, + ) + + return dataloader, len(lsmdc_dataset), train_sampler + +def dataloader_lsmdc_test(args, tokenizer, subset="test"): + lsmdc_testset = LSMDC_DataLoader( + subset=subset, + data_path=args.data_path, + features_path=args.features_path, + max_words=args.max_words, + feature_framerate=args.feature_framerate, + tokenizer=tokenizer, + max_frames=args.max_frames, + frame_order=args.eval_frame_order, + slice_framepos=args.slice_framepos, + ) + dataloader_msrvtt = DataLoader( + lsmdc_testset, + batch_size=args.batch_size_val, + num_workers=args.num_thread_reader, + shuffle=False, + drop_last=False, + ) + return dataloader_msrvtt, len(lsmdc_testset) + + +def dataloader_activity_train(args, tokenizer): + activity_dataset = ActivityNet_DataLoader( + subset="train", + data_path=args.data_path, + features_path=args.features_path, + max_words=args.max_words, + feature_framerate=args.feature_framerate, + tokenizer=tokenizer, + max_frames=args.max_frames, + frame_order=args.train_frame_order, + slice_framepos=args.slice_framepos, + ) + + train_sampler = torch.utils.data.distributed.DistributedSampler(activity_dataset) + dataloader = DataLoader( + activity_dataset, + batch_size=args.batch_size // args.n_gpu, + num_workers=args.num_thread_reader, + pin_memory=False, + shuffle=(train_sampler is None), + sampler=train_sampler, + drop_last=True, + ) + + return dataloader, len(activity_dataset), train_sampler + +def dataloader_activity_test(args, tokenizer, subset="test"): + activity_testset = ActivityNet_DataLoader( + subset=subset, + data_path=args.data_path, + features_path=args.features_path, + max_words=args.max_words, + feature_framerate=args.feature_framerate, + tokenizer=tokenizer, + max_frames=args.max_frames, + frame_order=args.eval_frame_order, + slice_framepos=args.slice_framepos, + ) + dataloader_msrvtt = DataLoader( + activity_testset, + batch_size=args.batch_size_val, + num_workers=args.num_thread_reader, + shuffle=False, + drop_last=False, + ) + return dataloader_msrvtt, len(activity_testset) + + +def dataloader_didemo_train(args, tokenizer): + didemo_dataset = DiDeMo_DataLoader( + subset="train", + data_path=args.data_path, + features_path=args.features_path, + max_words=args.max_words, + feature_framerate=args.feature_framerate, + tokenizer=tokenizer, + max_frames=args.max_frames, + frame_order=args.train_frame_order, + slice_framepos=args.slice_framepos, + ) + + train_sampler = torch.utils.data.distributed.DistributedSampler(didemo_dataset) + dataloader = DataLoader( + didemo_dataset, + batch_size=args.batch_size // args.n_gpu, + num_workers=args.num_thread_reader, + pin_memory=False, + shuffle=(train_sampler is None), + sampler=train_sampler, + drop_last=True, + ) + + return dataloader, len(didemo_dataset), train_sampler + +def dataloader_didemo_test(args, tokenizer, subset="test"): + didemo_testset = DiDeMo_DataLoader( + subset=subset, + data_path=args.data_path, + features_path=args.features_path, + max_words=args.max_words, + feature_framerate=args.feature_framerate, + tokenizer=tokenizer, + max_frames=args.max_frames, + frame_order=args.eval_frame_order, + slice_framepos=args.slice_framepos, + ) + dataloader_didemo = DataLoader( + didemo_testset, + batch_size=args.batch_size_val, + num_workers=args.num_thread_reader, + shuffle=False, + drop_last=False, + ) + return dataloader_didemo, len(didemo_testset) + + +DATALOADER_DICT = {} +DATALOADER_DICT["msrvtt"] = {"train":dataloader_msrvtt_train, "val":dataloader_msrvtt_test, "test":None} +DATALOADER_DICT["msvd"] = {"train":dataloader_msvd_train, "val":dataloader_msvd_test, "test":dataloader_msvd_test} +DATALOADER_DICT["lsmdc"] = {"train":dataloader_lsmdc_train, "val":dataloader_lsmdc_test, "test":dataloader_lsmdc_test} +DATALOADER_DICT["activity"] = {"train":dataloader_activity_train, "val":dataloader_activity_test, "test":None} +DATALOADER_DICT["didemo"] = {"train":dataloader_didemo_train, "val":dataloader_didemo_test, "test":dataloader_didemo_test} + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--val_vl_ret_data", default="", type=str, help="Point the dataset to finetune.") + parser.add_argument("--val_v_cls_data", default="", type=str, help="Point the dataset to finetune.") + parser.add_argument("--do_train", action='store_true', help="Whether to run training.") + parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") + parser.add_argument('--train_csv', type=str, default='data/.train.csv', help='') + parser.add_argument('--val_csv', type=str, default='data/.val.csv', help='') + parser.add_argument('--data_path', type=str, default='data/caption.pickle', help='data pickle file path') + parser.add_argument('--features_path', type=str, default='data/videos_feature.pickle', help='feature path') + parser.add_argument('--eval_frame_order', type=int, default=0, choices=[0, 1, 2], + help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.") + parser.add_argument('--feature_framerate', type=int, default=1, help='') + parser.add_argument('--slice_framepos', type=int, default=2, choices=[0, 1, 2], + help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.") + parser.add_argument('--max_frames', type=int, default=100, help='') + parser.add_argument('--max_words', type=int, default=20, help='') + parser.add_argument('--batch_size', type=int, default=77, help='') + parser.add_argument('--workers', type=int, default=0, help='') + parser.add_argument('--batch_size_val', type=int, default=0, help='batch size eval') + parser.add_argument('--num_thread_reader', type=int, default=0, help='') + parser.add_argument('--num_frames', type=int, default=8, help='') + args = parser.parse_args() + + args.val_vl_ret_data = 'msrvtt' + args.do_train = False + args.do_eval = True + args.slice_framepos = 2 + args.max_words = 77 + args.train_csv = 'D:/MSRVTT/MSRVTT_train.9k.csv' + args.val_csv = 'D:/MSRVTT/MSRVTT_JSFUSION_test.csv' + args.data_path = 'D:/MSRVTT/MSRVTT_data.json' + args.features_path = 'D:/MSRVTT/videos/all' + + dataloader_msrvtt = get_data(args)["vl_ret"] + for batch in dataloader_msrvtt: + print() \ No newline at end of file diff --git a/vl_ret/dataloader_activitynet_retrieval.py b/vl_ret/dataloader_activitynet_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..5467579c08086e385eb4964b3d5ea37704cf4acb --- /dev/null +++ b/vl_ret/dataloader_activitynet_retrieval.py @@ -0,0 +1,241 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals +from __future__ import print_function + +import os +from torch.utils.data import Dataset +import numpy as np +import json +import math +from .rawvideo_util import RawVideoExtractor + +class ActivityNet_DataLoader(Dataset): + def __init__( + self, + subset, + data_path, + features_path, + tokenizer, + max_words=30, + feature_framerate=1.0, + max_frames=100, + image_resolution=224, + frame_order=0, + slice_framepos=0, + ): + self.data_path = data_path + self.features_path = features_path + self.feature_framerate = feature_framerate + self.max_words = max_words + self.max_frames = max_frames + self.tokenizer = tokenizer + # 0: ordinary order; 1: reverse order; 2: random order. + self.frame_order = frame_order + assert self.frame_order in [0, 1, 2] + # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. + self.slice_framepos = slice_framepos + assert self.slice_framepos in [0, 1, 2] + + self.subset = subset + assert self.subset in ["train", "val"] + + video_id_path_dict = {} + video_id_path_dict["train"] = os.path.join(self.data_path, "train_ids.json") + video_id_path_dict["val"] = os.path.join(self.data_path, "val_ids.json") + + video_json_path_dict = {} + video_json_path_dict["train"] = os.path.join(self.data_path, "train.json") + video_json_path_dict["val"] = os.path.join(self.data_path, "val_1.json") + + pseudo_video_id_list, video_id_list = self._get_video_id_single(video_id_path_dict[self.subset]) + pseudo_caption_dict = self._get_captions_single(video_json_path_dict[self.subset]) + + print("video id list: {}".format(len(video_id_list))) + print("pseudo caption dict: {}".format(len(pseudo_caption_dict.keys()))) + + video_dict = {} + for root, dub_dir, video_files in os.walk(self.features_path): + for video_file in video_files: + video_id_ = ".".join(video_file.split(".")[:-1])[2:] + if video_id_ not in video_id_list: + continue + file_path_ = os.path.join(root, video_file) + video_dict[video_id_] = file_path_ + self.video_dict = video_dict + print("video dict: {}".format(len(video_dict))) + + self.pseudo_video_id_list = pseudo_video_id_list + self.video_id_list = video_id_list + self.pseudo_caption_dict = pseudo_caption_dict + + # Get iterator video ids + self.video_id2idx_dict = {pseudo_video_id: id for id, pseudo_video_id in enumerate(self.pseudo_video_id_list)} + # Get all captions + self.iter2video_pairs_dict = {} + for pseudo_video_id, video_id in zip(self.pseudo_video_id_list, self.video_id_list): + if pseudo_video_id not in self.pseudo_caption_dict or video_id not in self.video_dict: + continue + caption = self.pseudo_caption_dict[pseudo_video_id] + n_caption = len(caption['start']) + for sub_id in range(n_caption): + self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (pseudo_video_id, sub_id) + + self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) + self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", + "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} + + def __len__(self): + return len(self.iter2video_pairs_dict) + + def _get_video_id_from_pseduo(self, pseudo_video_id): + video_id = pseudo_video_id[2:] + return video_id + + def _get_video_id_single(self, path): + pseudo_video_id_list = [] + video_id_list = [] + print('Loading json: {}'.format(path)) + with open(path, 'r') as f: + json_data = json.load(f) + + for pseudo_video_id in json_data: + if pseudo_video_id in pseudo_video_id_list: + print("reduplicate.") + else: + video_id = self._get_video_id_from_pseduo(pseudo_video_id) + pseudo_video_id_list.append(pseudo_video_id) + video_id_list.append(video_id) + return pseudo_video_id_list, video_id_list + + def _get_captions_single(self, path): + pseudo_caption_dict = {} + with open(path, 'r') as f: + json_data = json.load(f) + + for pseudo_video_id, v_ in json_data.items(): + pseudo_caption_dict[pseudo_video_id] = {} + duration = v_["duration"] + pseudo_caption_dict[pseudo_video_id]["start"] = np.array([0], dtype=object) + pseudo_caption_dict[pseudo_video_id]["end"] = np.array([int(math.ceil(float(duration)))], dtype=object) + pseudo_caption_dict[pseudo_video_id]["text"] = np.array([" ".join(v_["sentences"])], dtype=object) + return pseudo_caption_dict + + def _get_text(self, pseudo_video_id, sub_id): + caption = self.pseudo_caption_dict[pseudo_video_id] + k = 1 + r_ind = [sub_id] + + starts = np.zeros(k, dtype=np.int64) + ends = np.zeros(k, dtype=np.int64) + pairs_text = np.zeros((k, self.max_words), dtype=np.int64) + pairs_mask = np.zeros((k, self.max_words), dtype=np.int64) + pairs_segment = np.zeros((k, self.max_words), dtype=np.int64) + + for i in range(k): + # ind = r_ind[i] + # start_, end_ = caption['start'][ind], caption['end'][ind] + # words = self.tokenizer.tokenize(caption['text'][ind]) + # starts[i], ends[i] = start_, end_ + # + # words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words + # total_length_with_CLS = self.max_words - 1 + # if len(words) > total_length_with_CLS: + # words = words[:total_length_with_CLS] + # words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] + # + # input_ids = self.tokenizer.convert_tokens_to_ids(words) + # input_mask = [1] * len(input_ids) + # segment_ids = [0] * len(input_ids) + + + ind = r_ind[i] + start_, end_ = caption['start'][ind], caption['end'][ind] + output = self.tokenizer(caption['text'][ind]) + starts[i], ends[i] = start_, end_ + + input_ids = output[0].squeeze() + input_mask = output[1].squeeze() + segment_ids = [0] * len(input_ids) + + + while len(input_ids) < self.max_words: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == self.max_words + assert len(input_mask) == self.max_words + assert len(segment_ids) == self.max_words + + pairs_text[i] = np.array(input_ids) + pairs_mask[i] = np.array(input_mask) + pairs_segment[i] = np.array(segment_ids) + + return pairs_text, pairs_mask, pairs_segment, starts, ends + + def _get_rawvideo(self, idx, s, e): + video_mask = np.zeros((len(s), self.max_frames), dtype=np.int64) + max_video_length = [0] * len(s) + + # Pair x L x T x 3 x H x W + video = np.zeros((len(s), self.max_frames, 1, 3, + self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float32) + video_path = self.video_dict[idx] + try: + for i in range(len(s)): + start_time = int(s[i]) + end_time = int(e[i]) + start_time = start_time if start_time >= 0. else 0. + end_time = end_time if end_time >= 0. else 0. + if start_time > end_time: + start_time, end_time = end_time, start_time + elif start_time == end_time: + end_time = end_time + 1 + + # Should be optimized by gathering all asking of this video + raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) + raw_video_data = raw_video_data['video'] + # print('raw_video_data', raw_video_data.shape) + + if len(raw_video_data.shape) > 3: + raw_video_data_clip = raw_video_data + # L x T x 3 x H x W + raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) + if self.max_frames < raw_video_slice.shape[0]: + if self.slice_framepos == 0: + video_slice = raw_video_slice[:self.max_frames, ...] + elif self.slice_framepos == 1: + video_slice = raw_video_slice[-self.max_frames:, ...] + else: + sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) + # print('sample_indx', raw_video_slice.shape[0], sample_indx) + video_slice = raw_video_slice[sample_indx, ...] + else: + video_slice = raw_video_slice + + video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) + + slice_len = video_slice.shape[0] + max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len + if slice_len < 1: + pass + else: + video[i][:slice_len, ...] = video_slice + else: + print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) + except Exception as excep: + print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep)) + raise excep + + for i, v_length in enumerate(max_video_length): + video_mask[i][:v_length] = [1] * v_length + + return video, video_mask + + def __getitem__(self, feature_idx): + pseudo_video_id, sub_id = self.iter2video_pairs_dict[feature_idx] + idx = self.video_id2idx_dict[pseudo_video_id] + + pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(pseudo_video_id, sub_id) + video, video_mask = self._get_rawvideo(self.video_id_list[idx], starts, ends) + return pairs_text, pairs_mask, pairs_segment, video, video_mask \ No newline at end of file diff --git a/vl_ret/dataloader_didemo_retrieval.py b/vl_ret/dataloader_didemo_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..11f30557b75208face7908d505e05f0542cbf864 --- /dev/null +++ b/vl_ret/dataloader_didemo_retrieval.py @@ -0,0 +1,238 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals +from __future__ import print_function + +import os +from torch.utils.data import Dataset +import numpy as np +import json +from .rawvideo_util import RawVideoExtractor + +class DiDeMo_DataLoader(Dataset): + def __init__( + self, + subset, + data_path, + features_path, + tokenizer, + max_words=30, + feature_framerate=1.0, + max_frames=100, + image_resolution=224, + frame_order=0, + slice_framepos=0, + ): + self.data_path = data_path + self.features_path = features_path + self.feature_framerate = feature_framerate + self.max_words = max_words + self.max_frames = max_frames + self.tokenizer = tokenizer + # 0: ordinary order; 1: reverse order; 2: random order. + self.frame_order = frame_order + assert self.frame_order in [0, 1, 2] + # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. + self.slice_framepos = slice_framepos + assert self.slice_framepos in [0, 1, 2] + + self.subset = subset + assert self.subset in ["train", "val", "test"] + + video_id_path_dict = {} + video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") + video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") + video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") + + video_json_path_dict = {} + video_json_path_dict["train"] = os.path.join(self.data_path, "train_data.json") + video_json_path_dict["val"] = os.path.join(self.data_path, "val_data.json") + video_json_path_dict["test"] = os.path.join(self.data_path, "test_data.json") + + with open(video_id_path_dict[self.subset], 'r') as fp: + video_ids = [itm.strip() for itm in fp.readlines()] + + caption_dict = {} + with open(video_json_path_dict[self.subset], 'r') as f: + json_data = json.load(f) + for itm in json_data: + description = itm["description"] + times = itm["times"] + video = itm["video"] + if video not in video_ids: + continue + + # each video is split into 5-second temporal chunks + # average the points from each annotator + start_ = np.mean([t_[0] for t_ in times]) * 5 + end_ = (np.mean([t_[1] for t_ in times]) + 1) * 5 + if video in caption_dict: + caption_dict[video]["start"].append(start_) + caption_dict[video]["end"].append(end_) + caption_dict[video]["text"].append(description) + else: + caption_dict[video] = {} + caption_dict[video]["start"] = [start_] + caption_dict[video]["end"] = [end_] + caption_dict[video]["text"] = [description] + + for k_ in caption_dict.keys(): + caption_dict[k_]["start"] = [0] + # trick to save time on obtaining each video length + # [https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md]: + # Some videos are longer than 30 seconds. These videos were truncated to 30 seconds during annotation. + caption_dict[k_]["end"] = [31] + caption_dict[k_]["text"] = [" ".join(caption_dict[k_]["text"])] + + video_dict = {} + for root, dub_dir, video_files in os.walk(self.features_path): + for video_file in video_files: + video_id_ = os.path.splitext(video_file)[0] ###############3 + if video_id_ not in video_ids: + continue + file_path_ = os.path.join(root, video_file) + video_dict[video_id_] = file_path_ + + self.caption_dict = caption_dict + self.video_dict = video_dict + video_ids = list(set(video_ids) & set(self.caption_dict.keys()) & set(self.video_dict.keys())) + + # Get all captions + self.iter2video_pairs_dict = {} + for video_id in self.caption_dict.keys(): + if video_id not in video_ids: + continue + caption = self.caption_dict[video_id] + n_caption = len(caption['start']) + for sub_id in range(n_caption): + self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (video_id, sub_id) + + self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) + self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", + "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} + + def __len__(self): + return len(self.iter2video_pairs_dict) + + def _get_text(self, video_id, sub_id): + caption = self.caption_dict[video_id] + k = 1 + r_ind = [sub_id] + + starts = np.zeros(k, dtype=np.int64) + ends = np.zeros(k, dtype=np.int64) + pairs_text = np.zeros((k, self.max_words), dtype=np.int64) + pairs_mask = np.zeros((k, self.max_words), dtype=np.int64) + pairs_segment = np.zeros((k, self.max_words), dtype=np.int64) + + for i in range(k): + # ind = r_ind[i] + # start_, end_ = caption['start'][ind], caption['end'][ind] + # words = self.tokenizer.tokenize(caption['text'][ind]) + # starts[i], ends[i] = start_, end_ + # + # words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words + # total_length_with_CLS = self.max_words - 1 + # if len(words) > total_length_with_CLS: + # words = words[:total_length_with_CLS] + # words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] + # + # input_ids = self.tokenizer.convert_tokens_to_ids(words) + # input_mask = [1] * len(input_ids) + # segment_ids = [0] * len(input_ids) + + + + ind = r_ind[i] + start_, end_ = caption['start'][ind], caption['end'][ind] + output = self.tokenizer(caption['text'][ind]) + starts[i], ends[i] = start_, end_ + + input_ids = output[0].squeeze() + input_mask = output[1].squeeze() + segment_ids = [0] * len(input_ids) + + + + while len(input_ids) < self.max_words: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == self.max_words + assert len(input_mask) == self.max_words + assert len(segment_ids) == self.max_words + + pairs_text[i] = np.array(input_ids) + pairs_mask[i] = np.array(input_mask) + pairs_segment[i] = np.array(segment_ids) + + return pairs_text, pairs_mask, pairs_segment, starts, ends + + def _get_rawvideo(self, idx, s, e): + video_mask = np.zeros((len(s), self.max_frames), dtype=np.int64) + max_video_length = [0] * len(s) + + # Pair x L x T x 3 x H x W + video = np.zeros((len(s), self.max_frames, 1, 3, + self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float32) + video_path = self.video_dict[idx] + + try: + for i in range(len(s)): + start_time = int(s[i]) + end_time = int(e[i]) + start_time = start_time if start_time >= 0. else 0. + end_time = end_time if end_time >= 0. else 0. + if start_time > end_time: + start_time, end_time = end_time, start_time + elif start_time == end_time: + end_time = end_time + 1 + + cache_id = "{}_{}_{}".format(video_path, start_time, end_time) + # Should be optimized by gathering all asking of this video + raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) + raw_video_data = raw_video_data['video'] + # print('raw_video_data', raw_video_data.shape) + + if len(raw_video_data.shape) > 3: + raw_video_data_clip = raw_video_data + # L x T x 3 x H x W + raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) + if self.max_frames < raw_video_slice.shape[0]: + if self.slice_framepos == 0: + video_slice = raw_video_slice[:self.max_frames, ...] + elif self.slice_framepos == 1: + video_slice = raw_video_slice[-self.max_frames:, ...] + else: + sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) + # print('sample_indx', raw_video_slice.shape[0], sample_indx) + video_slice = raw_video_slice[sample_indx, ...] + else: + video_slice = raw_video_slice + + video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) + + slice_len = video_slice.shape[0] + max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len + if slice_len < 1: + pass + else: + video[i][:slice_len, ...] = video_slice + else: + print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) + except Exception as excep: + print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep)) + pass + # raise e + + for i, v_length in enumerate(max_video_length): + video_mask[i][:v_length] = [1] * v_length + + return video, video_mask + + def __getitem__(self, feature_idx): + video_id, sub_id = self.iter2video_pairs_dict[feature_idx] + + pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(video_id, sub_id) + video, video_mask = self._get_rawvideo(video_id, starts, ends) + return pairs_text, pairs_mask, pairs_segment, video, video_mask \ No newline at end of file diff --git a/vl_ret/dataloader_lsmdc_retrieval.py b/vl_ret/dataloader_lsmdc_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..3575002d05bf94273caa561cc8b71831a4de5cb2 --- /dev/null +++ b/vl_ret/dataloader_lsmdc_retrieval.py @@ -0,0 +1,208 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals +from __future__ import print_function + +import os +from torch.utils.data import Dataset +import numpy as np +import json +import math +from .rawvideo_util import RawVideoExtractor + +class LSMDC_DataLoader(Dataset): + """LSMDC dataset loader.""" + def __init__( + self, + subset, + data_path, + features_path, + tokenizer, + max_words=30, + feature_framerate=1.0, + max_frames=100, + image_resolution=224, + frame_order=0, + slice_framepos=0, + ): + self.data_path = data_path + self.features_path = features_path + self.feature_framerate = feature_framerate + self.max_words = max_words + self.max_frames = max_frames + self.tokenizer = tokenizer + # 0: ordinary order; 1: reverse order; 2: random order. + self.frame_order = frame_order + assert self.frame_order in [0, 1, 2] + # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. + self.slice_framepos = slice_framepos + assert self.slice_framepos in [0, 1, 2] + + self.subset = subset + assert self.subset in ["train", "val", "test"] + + video_json_path_dict = {} + video_json_path_dict["train"] = os.path.join(self.data_path, "LSMDC16_annos_training.csv") + video_json_path_dict["val"] = os.path.join(self.data_path, "LSMDC16_annos_val.csv") + video_json_path_dict["test"] = os.path.join(self.data_path, "LSMDC16_challenge_1000_publictect.csv") + + # \t\t\t\t\t + # is not a unique identifier, i.e. the same can be associated with multiple sentences. + # However, LSMDC16_challenge_1000_publictect.csv has no repeat instances + video_id_list = [] + caption_dict = {} + with open(video_json_path_dict[self.subset], 'r') as fp: + for line in fp: + line = line.strip() + line_split = line.split("\t") + assert len(line_split) == 6 + clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence = line_split + caption_dict[len(caption_dict)] = (clip_id, sentence) + if clip_id not in video_id_list: video_id_list.append(clip_id) + + video_dict = {} + for root, dub_dir, video_files in os.walk(self.features_path): + for video_file in video_files: + video_id_ = ".".join(video_file.split(".")[:-1]) + if video_id_ not in video_id_list: + continue + file_path_ = os.path.join(root, video_file) + video_dict[video_id_] = file_path_ + + self.video_dict = video_dict + + # Get all captions + self.iter2video_pairs_dict = {} + for clip_id, sentence in caption_dict.values(): + if clip_id not in self.video_dict: + continue + self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (clip_id, sentence) + + self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) + self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", + "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} + + def __len__(self): + return len(self.iter2video_pairs_dict) + + def _get_video_id_from_pseduo(self, pseudo_video_id): + video_id = pseudo_video_id[2:] + return video_id + + def _get_video_id_single(self, path): + pseudo_video_id_list = [] + video_id_list = [] + print('Loading json: {}'.format(path)) + with open(path, 'r') as f: + json_data = json.load(f) + + for pseudo_video_id in json_data: + if pseudo_video_id in pseudo_video_id_list: + print("reduplicate.") + else: + video_id = self._get_video_id_from_pseduo(pseudo_video_id) + pseudo_video_id_list.append(pseudo_video_id) + video_id_list.append(video_id) + return pseudo_video_id_list, video_id_list + + def _get_captions_single(self, path): + pseudo_caption_dict = {} + with open(path, 'r') as f: + json_data = json.load(f) + + for pseudo_video_id, v_ in json_data.items(): + pseudo_caption_dict[pseudo_video_id] = {} + timestamps = v_["timestamps"] + pseudo_caption_dict[pseudo_video_id]["start"] = \ + np.array([int(math.floor(float(itm[0]))) for itm in timestamps], dtype=object) + pseudo_caption_dict[pseudo_video_id]["end"] = \ + np.array([int(math.ceil(float(itm[1]))) for itm in timestamps], dtype=object) + pseudo_caption_dict[pseudo_video_id]["text"] = np.array(v_["sentences"], dtype=object) + return pseudo_caption_dict + + def _get_text(self, video_id, caption): + k = 1 + choice_video_ids = [video_id] + pairs_text = np.zeros((k, self.max_words), dtype=np.int64) + pairs_mask = np.zeros((k, self.max_words), dtype=np.int64) + pairs_segment = np.zeros((k, self.max_words), dtype=np.int64) + + for i, video_id in enumerate(choice_video_ids): + words = self.tokenizer.tokenize(caption) + + words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words + total_length_with_CLS = self.max_words - 1 + if len(words) > total_length_with_CLS: + words = words[:total_length_with_CLS] + words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] + + input_ids = self.tokenizer.convert_tokens_to_ids(words) + input_mask = [1] * len(input_ids) + segment_ids = [0] * len(input_ids) + while len(input_ids) < self.max_words: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == self.max_words + assert len(input_mask) == self.max_words + assert len(segment_ids) == self.max_words + + pairs_text[i] = np.array(input_ids) + pairs_mask[i] = np.array(input_mask) + pairs_segment[i] = np.array(segment_ids) + + return pairs_text, pairs_mask, pairs_segment, choice_video_ids + + def _get_rawvideo(self, choice_video_ids): + video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64) + max_video_length = [0] * len(choice_video_ids) + + # Pair x L x T x 3 x H x W + video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, + self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float32) + + try: + for i, video_id in enumerate(choice_video_ids): + video_path = self.video_dict[video_id] + + raw_video_data = self.rawVideoExtractor.get_video_data(video_path) + raw_video_data = raw_video_data['video'] + + if len(raw_video_data.shape) > 3: + raw_video_data_clip = raw_video_data + # L x T x 3 x H x W + raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) + if self.max_frames < raw_video_slice.shape[0]: + if self.slice_framepos == 0: + video_slice = raw_video_slice[:self.max_frames, ...] + elif self.slice_framepos == 1: + video_slice = raw_video_slice[-self.max_frames:, ...] + else: + sample_indx = np.linspace(0, raw_video_slice.shape[0]-1, num=self.max_frames, dtype=int) + video_slice = raw_video_slice[sample_indx, ...] + else: + video_slice = raw_video_slice + + video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) + + slice_len = video_slice.shape[0] + max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len + if slice_len < 1: + pass + else: + video[i][:slice_len, ...] = video_slice + else: + print("video path: {} error. video id: {}".format(video_path, video_id)) + except Exception as excep: + print("Video ids: {}".format(choice_video_ids)) + raise excep + + for i, v_length in enumerate(max_video_length): + video_mask[i][:v_length] = [1] * v_length + return video, video_mask + + def __getitem__(self, feature_idx): + clip_id, sentence = self.iter2video_pairs_dict[feature_idx] + pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(clip_id, sentence) + video, video_mask = self._get_rawvideo(choice_video_ids) + return pairs_text, pairs_mask, pairs_segment, video, video_mask \ No newline at end of file diff --git a/vl_ret/dataloader_msrvtt_retrieval.py b/vl_ret/dataloader_msrvtt_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..53ba986c9daf8919bb728536fee50b820534e557 --- /dev/null +++ b/vl_ret/dataloader_msrvtt_retrieval.py @@ -0,0 +1,311 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals +from __future__ import print_function + +import os +from torch.utils.data import Dataset +import numpy as np +import pandas as pd +from collections import defaultdict +import json +import random +from .rawvideo_util import RawVideoExtractor + +class MSRVTT_DataLoader(Dataset): + """MSRVTT dataset loader.""" + def __init__( + self, + csv_path, + features_path, + tokenizer, + max_words=30, + feature_framerate=1.0, + max_frames=100, + image_resolution=224, + frame_order=0, + slice_framepos=0, + ): + self.data = pd.read_csv(csv_path) + self.features_path = features_path + self.feature_framerate = feature_framerate + self.max_words = max_words + self.max_frames = max_frames + self.tokenizer = tokenizer + # 0: ordinary order; 1: reverse order; 2: random order. + self.frame_order = frame_order + assert self.frame_order in [0, 1, 2] + # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. + self.slice_framepos = slice_framepos + assert self.slice_framepos in [0, 1, 2] + + self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) + self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", + "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} + + def __len__(self): + return len(self.data) + + def _get_text(self, video_id, sentence): + choice_video_ids = [video_id] + n_caption = len(choice_video_ids) + + k = n_caption + pairs_text = np.zeros((k, self.max_words), dtype=np.int64) + pairs_mask = np.zeros((k, self.max_words), dtype=np.int64) + pairs_segment = np.zeros((k, self.max_words), dtype=np.int64) + + for i, video_id in enumerate(choice_video_ids): + # words = self.tokenizer.tokenize(sentence) + # + # words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words + # total_length_with_CLS = self.max_words - 1 + # if len(words) > total_length_with_CLS: + # words = words[:total_length_with_CLS] + # words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] + # + # input_ids = self.tokenizer.convert_tokens_to_ids(words) + # input_mask = [1] * len(input_ids) + # segment_ids = [0] * len(input_ids) + + + output = self.tokenizer(sentence) + + input_ids = output[0].squeeze() + input_mask = output[1].squeeze() + segment_ids = [0] * len(input_ids) + + + while len(input_ids) < self.max_words: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == self.max_words + assert len(input_mask) == self.max_words + assert len(segment_ids) == self.max_words + + pairs_text[i] = np.array(input_ids) + pairs_mask[i] = np.array(input_mask) + pairs_segment[i] = np.array(segment_ids) + + return pairs_text, pairs_mask, pairs_segment, choice_video_ids + + def _get_rawvideo(self, choice_video_ids): + video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64) + max_video_length = [0] * len(choice_video_ids) + + # Pair x L x T x 3 x H x W + video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, + self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float32) + + for i, video_id in enumerate(choice_video_ids): + # Individual for YoucokII dataset, due to it video format + video_path = os.path.join(self.features_path, "{}.mp4".format(video_id)) + if os.path.exists(video_path) is False: + video_path = video_path.replace(".mp4", ".webm") + + raw_video_data = self.rawVideoExtractor.get_video_data(video_path) + raw_video_data = raw_video_data['video'] + # print('raw_video_data', raw_video_data.shape) + if len(raw_video_data.shape) > 3: + raw_video_data_clip = raw_video_data + # L x T x 3 x H x W + raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) + if self.max_frames < raw_video_slice.shape[0]: + if self.slice_framepos == 0: + video_slice = raw_video_slice[:self.max_frames, ...] + elif self.slice_framepos == 1: + video_slice = raw_video_slice[-self.max_frames:, ...] + else: + sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) + # print('sample_indx', raw_video_slice.shape[0], sample_indx) + video_slice = raw_video_slice[sample_indx, ...] + else: + video_slice = raw_video_slice + + video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) + + slice_len = video_slice.shape[0] + max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len + if slice_len < 1: + pass + else: + video[i][:slice_len, ...] = video_slice + else: + print("video path: {} error. video id: {}".format(video_path, video_id)) + + for i, v_length in enumerate(max_video_length): + video_mask[i][:v_length] = [1] * v_length + + return video, video_mask + + def __getitem__(self, idx): + video_id = self.data['video_id'].values[idx] + sentence = self.data['sentence'].values[idx] + + pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, sentence) + video, video_mask = self._get_rawvideo(choice_video_ids) + return pairs_text, pairs_mask, pairs_segment, video, video_mask + +class MSRVTT_TrainDataLoader(Dataset): + """MSRVTT train dataset loader.""" + def __init__( + self, + csv_path, + json_path, + features_path, + tokenizer, + max_words=30, + feature_framerate=1.0, + max_frames=100, + unfold_sentences=False, + image_resolution=224, + frame_order=0, + slice_framepos=0, + ): + self.csv = pd.read_csv(csv_path) + self.data = json.load(open(json_path, 'r')) + self.features_path = features_path + self.feature_framerate = feature_framerate + self.max_words = max_words + self.max_frames = max_frames + self.tokenizer = tokenizer + # 0: ordinary order; 1: reverse order; 2: random order. + self.frame_order = frame_order + assert self.frame_order in [0, 1, 2] + # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. + self.slice_framepos = slice_framepos + assert self.slice_framepos in [0, 1, 2] + + self.unfold_sentences = unfold_sentences + self.sample_len = 0 + if self.unfold_sentences: + train_video_ids = list(self.csv['video_id'].values) + self.sentences_dict = {} + for itm in self.data['sentences']: + if itm['video_id'] in train_video_ids: + self.sentences_dict[len(self.sentences_dict)] = (itm['video_id'], itm['caption']) + self.sample_len = len(self.sentences_dict) + else: + num_sentences = 0 + self.sentences = defaultdict(list) + s_video_id_set = set() + for itm in self.data['sentences']: + self.sentences[itm['video_id']].append(itm['caption']) + num_sentences += 1 + s_video_id_set.add(itm['video_id']) + + # Use to find the clips in the same video + self.parent_ids = {} + self.children_video_ids = defaultdict(list) + for itm in self.data['videos']: + vid = itm["video_id"] + url_posfix = itm["url"].split("?v=")[-1] + self.parent_ids[vid] = url_posfix + self.children_video_ids[url_posfix].append(vid) + self.sample_len = len(self.csv) + + self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) + self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", + "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} + + def __len__(self): + return self.sample_len + + def _get_text(self, video_id, caption=None): + k = 1 + choice_video_ids = [video_id] + pairs_text = np.zeros((k, self.max_words), dtype=np.int64) + pairs_mask = np.zeros((k, self.max_words), dtype=np.int64) + pairs_segment = np.zeros((k, self.max_words), dtype=np.int64) + + for i, video_id in enumerate(choice_video_ids): + if caption is not None: + words = self.tokenizer.tokenize(caption) + else: + words = self._get_single_text(video_id) + + words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words + total_length_with_CLS = self.max_words - 1 + if len(words) > total_length_with_CLS: + words = words[:total_length_with_CLS] + words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] + + input_ids = self.tokenizer.convert_tokens_to_ids(words) + input_mask = [1] * len(input_ids) + segment_ids = [0] * len(input_ids) + while len(input_ids) < self.max_words: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == self.max_words + assert len(input_mask) == self.max_words + assert len(segment_ids) == self.max_words + + pairs_text[i] = np.array(input_ids) + pairs_mask[i] = np.array(input_mask) + pairs_segment[i] = np.array(segment_ids) + + return pairs_text, pairs_mask, pairs_segment, choice_video_ids + + def _get_single_text(self, video_id): + rind = random.randint(0, len(self.sentences[video_id]) - 1) + caption = self.sentences[video_id][rind] + words = self.tokenizer.tokenize(caption) + return words + + def _get_rawvideo(self, choice_video_ids): + video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64) + max_video_length = [0] * len(choice_video_ids) + + # Pair x L x T x 3 x H x W + video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, + self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float32) + + for i, video_id in enumerate(choice_video_ids): + # Individual for YoucokII dataset, due to it video format + video_path = os.path.join(self.features_path, "{}.mp4".format(video_id)) + if os.path.exists(video_path) is False: + video_path = video_path.replace(".mp4", ".webm") + + raw_video_data = self.rawVideoExtractor.get_video_data(video_path) + raw_video_data = raw_video_data['video'] + if len(raw_video_data.shape) > 3: + raw_video_data_clip = raw_video_data + # L x T x 3 x H x W + raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) + if self.max_frames < raw_video_slice.shape[0]: + if self.slice_framepos == 0: + video_slice = raw_video_slice[:self.max_frames, ...] + elif self.slice_framepos == 1: + video_slice = raw_video_slice[-self.max_frames:, ...] + else: + sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) + video_slice = raw_video_slice[sample_indx, ...] + else: + video_slice = raw_video_slice + + video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) + + slice_len = video_slice.shape[0] + max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len + if slice_len < 1: + pass + else: + video[i][:slice_len, ...] = video_slice + else: + print("video path: {} error. video id: {}".format(video_path, video_id)) + + for i, v_length in enumerate(max_video_length): + video_mask[i][:v_length] = [1] * v_length + + return video, video_mask + + def __getitem__(self, idx): + if self.unfold_sentences: + video_id, caption = self.sentences_dict[idx] + else: + video_id, caption = self.csv['video_id'].values[idx], None + pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption) + video, video_mask = self._get_rawvideo(choice_video_ids) + return pairs_text, pairs_mask, pairs_segment, video, video_mask \ No newline at end of file diff --git a/vl_ret/dataloader_msvd_retrieval.py b/vl_ret/dataloader_msvd_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..477afed0f320880ae8cb83daa96a264e84184f3f --- /dev/null +++ b/vl_ret/dataloader_msvd_retrieval.py @@ -0,0 +1,191 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals +from __future__ import print_function + +import os +from torch.utils.data import Dataset +import numpy as np +import pickle +from .rawvideo_util import RawVideoExtractor + +class MSVD_DataLoader(Dataset): + """MSVD dataset loader.""" + def __init__( + self, + subset, + data_path, + features_path, + tokenizer, + max_words=30, + feature_framerate=1.0, + max_frames=100, + image_resolution=224, + frame_order=0, + slice_framepos=0, + ): + self.data_path = data_path + self.features_path = features_path + self.feature_framerate = feature_framerate + self.max_words = max_words + self.max_frames = max_frames + self.tokenizer = tokenizer + # 0: ordinary order; 1: reverse order; 2: random order. + self.frame_order = frame_order + assert self.frame_order in [0, 1, 2] + # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. + self.slice_framepos = slice_framepos + assert self.slice_framepos in [0, 1, 2] + + self.subset = subset + assert self.subset in ["train", "val", "test"] + video_id_path_dict = {} + video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") + video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") + video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") + caption_file = os.path.join(self.data_path, "raw-captions.pkl") + + with open(video_id_path_dict[self.subset], 'r') as fp: + video_ids = [itm.strip() for itm in fp.readlines()] + + with open(caption_file, 'rb') as f: + captions = pickle.load(f) + + video_dict = {} + for root, dub_dir, video_files in os.walk(self.features_path): + for video_file in video_files: + video_id_ = ".".join(video_file.split(".")[:-1]) + if video_id_ not in video_ids: + continue + file_path_ = os.path.join(root, video_file) + video_dict[video_id_] = file_path_ + self.video_dict = video_dict + + self.sample_len = 0 + self.sentences_dict = {} + self.cut_off_points = [] + for video_id in video_ids: + assert video_id in captions + for cap in captions[video_id]: + cap_txt = " ".join(cap) + self.sentences_dict[len(self.sentences_dict)] = (video_id, cap_txt) + self.cut_off_points.append(len(self.sentences_dict)) + + ## below variables are used to multi-sentences retrieval + # self.cut_off_points: used to tag the label when calculate the metric + # self.sentence_num: used to cut the sentence representation + # self.video_num: used to cut the video representation + self.multi_sentence_per_video = True # !!! important tag for eval + if self.subset == "val" or self.subset == "test": + self.sentence_num = len(self.sentences_dict) + self.video_num = len(video_ids) + assert len(self.cut_off_points) == self.video_num + print("For {}, sentence number: {}".format(self.subset, self.sentence_num)) + print("For {}, video number: {}".format(self.subset, self.video_num)) + + print("Video number: {}".format(len(self.video_dict))) + print("Total Paire: {}".format(len(self.sentences_dict))) + + self.sample_len = len(self.sentences_dict) + self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) + self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", + "MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} + + def __len__(self): + return self.sample_len + + def _get_text(self, video_id, caption): + k = 1 + choice_video_ids = [video_id] + pairs_text = np.zeros((k, self.max_words), dtype=np.int64) + pairs_mask = np.zeros((k, self.max_words), dtype=np.int64) + pairs_segment = np.zeros((k, self.max_words), dtype=np.int64) + + for i, video_id in enumerate(choice_video_ids): + # words = self.tokenizer.tokenize(caption) + # + # words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words + # total_length_with_CLS = self.max_words - 1 + # if len(words) > total_length_with_CLS: + # words = words[:total_length_with_CLS] + # words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] + # + # input_ids = self.tokenizer.convert_tokens_to_ids(words) + # input_mask = [1] * len(input_ids) + # segment_ids = [0] * len(input_ids) + + + output = self.tokenizer(caption) + + input_ids = output[0].squeeze() + input_mask = output[1].squeeze() + segment_ids = [0] * len(input_ids) + + + while len(input_ids) < self.max_words: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == self.max_words + assert len(input_mask) == self.max_words + assert len(segment_ids) == self.max_words + + pairs_text[i] = np.array(input_ids) + pairs_mask[i] = np.array(input_mask) + pairs_segment[i] = np.array(segment_ids) + + return pairs_text, pairs_mask, pairs_segment, choice_video_ids + + def _get_rawvideo(self, choice_video_ids): + video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64) + max_video_length = [0] * len(choice_video_ids) + + # Pair x L x T x 3 x H x W + video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3, + self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float32) + + for i, video_id in enumerate(choice_video_ids): + video_path = self.video_dict[video_id] + + raw_video_data = self.rawVideoExtractor.get_video_data(video_path) + raw_video_data = raw_video_data['video'] + # print('raw_video_data', raw_video_data.shape) + + if len(raw_video_data.shape) > 3: + raw_video_data_clip = raw_video_data + # L x T x 3 x H x W + raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) + if self.max_frames < raw_video_slice.shape[0]: + if self.slice_framepos == 0: + video_slice = raw_video_slice[:self.max_frames, ...] + elif self.slice_framepos == 1: + video_slice = raw_video_slice[-self.max_frames:, ...] + else: + sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) + # print('sample_indx', raw_video_slice.shape[0], sample_indx) + video_slice = raw_video_slice[sample_indx, ...] + else: + video_slice = raw_video_slice + + video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) + + slice_len = video_slice.shape[0] + max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len + if slice_len < 1: + pass + else: + video[i][:slice_len, ...] = video_slice + else: + print("video path: {} error. video id: {}".format(video_path, video_id)) + + for i, v_length in enumerate(max_video_length): + video_mask[i][:v_length] = [1] * v_length + + return video, video_mask + + def __getitem__(self, idx): + video_id, caption = self.sentences_dict[idx] + + pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption) + video, video_mask = self._get_rawvideo(choice_video_ids) + return pairs_text, pairs_mask, pairs_segment, video, video_mask \ No newline at end of file diff --git a/vl_ret/metrics.py b/vl_ret/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..708f8c9aec43a3b4b768f6a22739a268d8c38a16 --- /dev/null +++ b/vl_ret/metrics.py @@ -0,0 +1,70 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals +from __future__ import print_function + +import numpy as np +import torch + +def compute_metrics(x): + sx = np.sort(-x, axis=1) + d = np.diag(-x) + d = d[:, np.newaxis] + ind = sx - d + ind = np.where(ind == 0) + ind = ind[1] + metrics = {} + metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) + metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) + metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) + metrics['MR'] = np.median(ind) + 1 + metrics["MedianR"] = metrics['MR'] + metrics["MeanR"] = np.mean(ind) + 1 + # metrics["cols"] = [int(i) for i in list(ind)] + return metrics + +def print_computed_metrics(metrics): + r1 = metrics['R1'] + r5 = metrics['R5'] + r10 = metrics['R10'] + mr = metrics['MR'] + print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr)) + +# below two functions directly come from: https://github.com/Deferf/Experiments +def tensor_text_to_video_metrics(sim_tensor, top_k = [1,5,10]): + if not torch.is_tensor(sim_tensor): + sim_tensor = torch.tensor(sim_tensor) + + # Permute sim_tensor so it represents a sequence of text-video similarity matrices. + # Then obtain the double argsort to position the rank on the diagonal + stacked_sim_matrices = sim_tensor.permute(1, 0, 2) + first_argsort = torch.argsort(stacked_sim_matrices, dim = -1, descending= True) + second_argsort = torch.argsort(first_argsort, dim = -1, descending= False) + + # Extracts ranks i.e diagonals + ranks = torch.flatten(torch.diagonal(second_argsort, dim1 = 1, dim2 = 2)) + + # Now we need to extract valid ranks, as some belong to inf padding values + permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1 = 0, dim2 = 2)) + mask = ~ torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data)) + valid_ranks = ranks[mask] + # A quick dimension check validates our results, there may be other correctness tests pending + # Such as dot product localization, but that is for other time. + #assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict]) + if not torch.is_tensor(valid_ranks): + valid_ranks = torch.tensor(valid_ranks) + results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k} + results["MedianR"] = float(torch.median(valid_ranks + 1)) + results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1)) + results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1)) + results['MR'] = results["MedianR"] + return results + +def tensor_video_to_text_sim(sim_tensor): + if not torch.is_tensor(sim_tensor): + sim_tensor = torch.tensor(sim_tensor) + # Code to avoid nans + sim_tensor[sim_tensor != sim_tensor] = float('-inf') + # Forms a similarity matrix for use with rank at k + values, _ = torch.max(sim_tensor, dim=1, keepdim=True) + return torch.squeeze(values).T diff --git a/vl_ret/rawvideo_util.py b/vl_ret/rawvideo_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a9ea4eb16b2526ab0478b9b909bd221223fdc2d1 --- /dev/null +++ b/vl_ret/rawvideo_util.py @@ -0,0 +1,141 @@ +import torch as th +import numpy as np +from PIL import Image +# pytorch=1.7.1 +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +# pip install opencv-python +import cv2 + +class RawVideoExtractorCV2(): + def __init__(self, centercrop=False, size=224, framerate=-1, ): + self.centercrop = centercrop + self.size = size + self.framerate = framerate + self.transform = self._transform(self.size) + + def _transform(self, n_px): + return Compose([ + Resize(n_px, interpolation=Image.BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + # def video_to_tensor(self, video_file, preprocess, sample_fp=0, start_time=None, end_time=None): + # if start_time is not None or end_time is not None: + # assert isinstance(start_time, int) and isinstance(end_time, int) \ + # and start_time > -1 and end_time > start_time + # assert sample_fp > -1 + # + # # Samples a frame sample_fp X frames. + # cap = cv2.VideoCapture(video_file) + # frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + # fps = int(cap.get(cv2.CAP_PROP_FPS)) + # + # total_duration = (frameCount + fps - 1) // fps + # start_sec, end_sec = 0, total_duration + # + # if start_time is not None: + # start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration + # cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) + # + # interval = 1 + # if sample_fp > 0: # 1 + # interval = fps // sample_fp # fps + # else: + # sample_fp = fps + # if interval == 0: interval = 1 + # + # inds = [ind for ind in np.arange(0, fps, interval)] # miao + # assert len(inds) >= sample_fp + # inds = inds[:sample_fp] + # + # ret = True + # images, included = [], [] + # + # for sec in np.arange(start_sec, end_sec + 1): + # if not ret: break + # sec_base = int(sec * fps) + # for ind in inds: + # cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) + # ret, frame = cap.read() + # if not ret: break + # frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + # images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB"))) + # + # cap.release() + # + # if len(images) > 0: + # video_data = th.tensor(np.stack(images)) + # else: + # video_data = th.zeros(1) + # return {'video': video_data} + + def video_to_tensor(self, video_file, preprocess, sample_fp=0, start_time=None, end_time=None): + if start_time is not None or end_time is not None: + assert isinstance(start_time, int) and isinstance(end_time, int) \ + and start_time > -1 and end_time > start_time + assert sample_fp > -1 + + # Samples a frame sample_fp X frames. + cap = cv2.VideoCapture(video_file) + frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = int(cap.get(cv2.CAP_PROP_FPS)) + + total_duration = (frameCount + fps - 1) // fps + start_sec, end_sec = 0, total_duration + + if start_time is not None: + start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration + cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) + + + ret = True + images, included = [], [] + + sta_frm, end_frm = int(start_sec * fps), int((end_sec-1) * fps) + inds = np.linspace(sta_frm, end_frm, num=8, dtype=int) + # print('sta_frm, end_frm, frameCount, inds, fps, total_duration', sta_frm, end_frm, frameCount, fps, total_duration, inds) + for idx, ind in enumerate(inds): + cap.set(cv2.CAP_PROP_POS_FRAMES, ind) + ret, frame = cap.read() + if not ret: + # print(f'break {idx}') + break + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB"))) + + cap.release() + + if len(images) > 0: + video_data = th.tensor(np.stack(images)) + else: + video_data = th.zeros(1) + return {'video': video_data} + + def get_video_data(self, video_path, start_time=None, end_time=None): + image_input = self.video_to_tensor(video_path, self.transform, sample_fp=self.framerate, start_time=start_time, end_time=end_time) + return image_input + + def process_raw_data(self, raw_video_data): + tensor_size = raw_video_data.size() + tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], tensor_size[-1]) + return tensor + + def process_frame_order(self, raw_video_data, frame_order=0): + # 0: ordinary order; 1: reverse order; 2: random order. + if frame_order == 0: + pass + elif frame_order == 1: + reverse_order = np.arange(raw_video_data.size(0) - 1, -1, -1) + raw_video_data = raw_video_data[reverse_order, ...] + elif frame_order == 2: + random_order = np.arange(raw_video_data.size(0)) + np.random.shuffle(random_order) + raw_video_data = raw_video_data[random_order, ...] + + return raw_video_data + +# An ordinary video frame extractor based CV2 +RawVideoExtractor = RawVideoExtractorCV2 \ No newline at end of file diff --git a/vl_ret/retrieval.py b/vl_ret/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..64c184e87db776a8105e4e2f46beb2e10b3141b1 --- /dev/null +++ b/vl_ret/retrieval.py @@ -0,0 +1,227 @@ +import json + +import torch +import numpy as np +import random +import os +import time +import argparse +import logging + +from open_clip import get_input_dtype + +from training.distributed import is_master +from .metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim + +# torch.distributed.init_process_group(backend="nccl") +from .util import parallel_apply + + +def _run_on_single_gpu(model, + # batch_list_t, batch_list_v, + batch_sequence_output_list, batch_visual_output_list): + sim_matrix = [] + for idx1 in range(len(batch_sequence_output_list)): + # input_mask, segment_ids, *_tmp = b1 + sequence_output = batch_sequence_output_list[idx1] + each_row = [] + for idx2 in range(len(batch_visual_output_list)): + # video_mask, *_tmp = b2 + visual_output = batch_visual_output_list[idx2] + # b1b2_logits, *_tmp = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask, + # loose_type=model.loose_type) + # logging.info(f"{model.logit_scale.device}, {visual_output.device}, {sequence_output.device}") + b1b2_logits = model.logit_scale * sequence_output @ visual_output.T + # print(model.logit_scale.device, visual_output.device, sequence_output.device) + # logging.info(f"{b1b2_logits.shape}, {b1b2_logits.device}") + b1b2_logits = b1b2_logits.cpu().detach().numpy() + each_row.append(b1b2_logits) + each_row = np.concatenate(tuple(each_row), axis=-1) + sim_matrix.append(each_row) + return sim_matrix + + + +def evaluate_vl_ret(model, data, epoch, args, tb_writer=None): + input_dtype = get_input_dtype(args.precision) + if is_master(args) and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)): + # print(data) + val_vl_ret_data = list(data.keys()) + # print(val_vl_ret_data) + assert len(val_vl_ret_data) == 1 + val_vl_ret_data = val_vl_ret_data[0] + test_dataloader = data[val_vl_ret_data] + device = model.device + n_gpu = torch.cuda.device_count() + logging.info(f"\nEval Epoch: {epoch}, eval Video-Text Retrieval under {val_vl_ret_data.upper()} test data") + if hasattr(model, 'module'): + model = model.module.to(device) + else: + model = model.to(device) + # ################################################################# + ## below variables are used to multi-sentences retrieval + # multi_sentence_: important tag for eval + # cut_off_points: used to tag the label when calculate the metric + # sentence_num: used to cut the sentence representation + # video_num: used to cut the video representation + # ################################################################# + multi_sentence_ = False + cut_off_points_, sentence_num_, video_num_ = [], -1, -1 + if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') and test_dataloader.dataset.multi_sentence_per_video: + multi_sentence_ = True + cut_off_points_ = test_dataloader.dataset.cut_off_points + sentence_num_ = test_dataloader.dataset.sentence_num + video_num_ = test_dataloader.dataset.video_num + cut_off_points_ = [itm - 1 for itm in cut_off_points_] + + if multi_sentence_: + logging.info("Eval under the multi-sentence per video clip setting.") + logging.info("sentence num: {}, video num: {}".format(sentence_num_, video_num_)) + + model.eval() + with torch.no_grad(): + # batch_list_t = [] + # batch_list_v = [] + batch_sequence_output_list, batch_visual_output_list = [], [] + total_video_num = 0 + + # ---------------------------- + # 1. cache the features + # ---------------------------- + for bid, batch in enumerate(test_dataloader): + # batch = tuple(t.to(device) for t in batch) + input_ids, attention_mask, _, video, _ = batch + # print(input_ids.shape, video.shape, video.dtype, end='-----') + input_ids = input_ids.squeeze().to(device) + attention_mask = attention_mask.squeeze().to(device) + # video = video.squeeze().permute(0, 2, 1, 3, 4).float().to(device) + video = video.float().to(device) + + # video = video.to(ddevice=device, dtype=input_dtype) + + + + # print(input_ids.shape, video.shape, video.dtype) + # print(input_ids.shape, video.shape) + if multi_sentence_: + # multi-sentences retrieval means: one clip has two or more descriptions. + b, *_t = video.shape + sequence_output = model.encode_text(input_ids, attention_mask) + # logging.info(f'multi: {sequence_output.shape}') + # sequence_output = model.get_sequence_output(input_ids, segment_ids, input_mask) + batch_sequence_output_list.append(sequence_output) + # batch_list_t.append((input_mask, segment_ids,)) + + s_, e_ = total_video_num, total_video_num + b + filter_inds = [itm - s_ for itm in cut_off_points_ if itm >= s_ and itm < e_] + + if len(filter_inds) > 0: + # video, video_mask = video[filter_inds, ...], video_mask[filter_inds, ...] + # print('before', video.shape) + video = video[filter_inds, ...] + # print('after', video.shape) + # visual_output = model.get_visual_output(video, video_mask) + visual_output = model.encode_image(video) + batch_visual_output_list.append(visual_output) + # batch_list_v.append((video_mask,)) + total_video_num += b + else: + sequence_output = model.encode_text(input_ids, attention_mask) + visual_output = model.encode_image(video) + # logging.info(f"{device}, {sequence_output.shape}, {visual_output.shape}") + # sequence_output, visual_output = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask) + + batch_sequence_output_list.append(sequence_output) + # batch_list_t.append((input_mask, segment_ids,)) + + batch_visual_output_list.append(visual_output) + # batch_list_v.append((video_mask,)) + + print(f"Process {val_vl_ret_data.upper()}: {bid}/{len(test_dataloader)}\r", end='') + + # ---------------------------------- + # 2. calculate the similarity + # ---------------------------------- + n_gpu = torch.cuda.device_count() + if n_gpu > 1: + device_ids = list(range(n_gpu)) + batch_t_output_splits = [] + batch_v_output_splits = [] + bacth_len = len(batch_sequence_output_list) + # print(bacth_len) + split_len = (bacth_len + n_gpu - 1) // n_gpu + for dev_id in device_ids: + s_, e_ = dev_id * split_len, (dev_id + 1) * split_len + if dev_id == 0: + + batch_t_output_splits.append(batch_sequence_output_list[s_:e_]) + batch_v_output_splits.append(batch_visual_output_list) + # print(len(batch_sequence_output_list[s_:e_]), len(batch_visual_output_list)) + else: + devc = torch.device('cuda:{}'.format(str(dev_id))) + + devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]] + batch_t_output_splits.append(devc_batch_list) + devc_batch_list = [b.to(devc) for b in batch_visual_output_list] + batch_v_output_splits.append(devc_batch_list) + # print(len(devc_batch_list), len(devc_batch_list)) + parameters_tuple_list = [( + batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids] + parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids) + sim_matrix = [] + for idx in range(len(parallel_outputs)): + sim_matrix += parallel_outputs[idx] + sim_matrix = np.concatenate(tuple(sim_matrix), axis=0) + else: + sim_matrix = _run_on_single_gpu(model, + # batch_list_t, batch_list_v, + batch_sequence_output_list, batch_visual_output_list) + sim_matrix = np.concatenate(tuple(sim_matrix), axis=0) + ##################################################################### + if multi_sentence_: + logging.info(f"{val_vl_ret_data.upper()} before reshape, sim matrix size: {sim_matrix.shape}") + cut_off_points2len_ = [itm + 1 for itm in cut_off_points_] + max_length = max([e_-s_ for s_, e_ in zip([0]+cut_off_points2len_[:-1], cut_off_points2len_)]) + sim_matrix_new = [] + for s_, e_ in zip([0] + cut_off_points2len_[:-1], cut_off_points2len_): + sim_matrix_new.append(np.concatenate((sim_matrix[s_:e_], + np.full((max_length-e_+s_, sim_matrix.shape[1]), -np.inf)), axis=0)) + sim_matrix = np.stack(tuple(sim_matrix_new), axis=0) + logging.info(f"{val_vl_ret_data.upper()} after reshape, sim matrix size: {sim_matrix.shape}") + + tv_metrics = tensor_text_to_video_metrics(sim_matrix) + vt_metrics = compute_metrics(tensor_video_to_text_sim(sim_matrix)) + else: + logging.info(f"{val_vl_ret_data.upper()} sim matrix size: {sim_matrix.shape[0]}, {sim_matrix.shape[1]}") + tv_metrics = compute_metrics(sim_matrix) + vt_metrics = compute_metrics(sim_matrix.T) + logging.info('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0]))) + + logging.info(f"{val_vl_ret_data.upper()} Text-to-Video:") + logging.info('\t>>> R@1: {:.1f} - R@5: {:.1f} - R@10: {:.1f} - Median R: {:.1f} - Mean R: {:.1f}'. + format(tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR'])) + logging.info(f"{val_vl_ret_data.upper()} Video-to-Text:") + logging.info('\t>>> V2T$R@1: {:.1f} - V2T$R@5: {:.1f} - V2T$R@10: {:.1f} - V2T$Median R: {:.1f} - V2T$Mean R: {:.1f}'. + format(vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR'])) + + + if args.save_logs: + for name, val in tv_metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/vl_ret/{val_vl_ret_data}/t2v/{name}", val, epoch) + for name, val in vt_metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/vl_ret/{val_vl_ret_data}/v2t/{name}", val, epoch) + + args.vl_ret_output_dir = os.path.join(args.log_base_path, f'vl_ret/{val_vl_ret_data}') + os.makedirs(args.vl_ret_output_dir, exist_ok=True) + with open(os.path.join(args.vl_ret_output_dir, "results.jsonl"), "a+") as f: + f.write(json.dumps({'t2v': tv_metrics})) + f.write("\n") + f.write(json.dumps({'v2t': vt_metrics})) + f.write("\n") + + # R1 = tv_metrics['R1'] + # return R1 + + # torch.distributed.barrier() \ No newline at end of file diff --git a/vl_ret/tokenization_clip.py b/vl_ret/tokenization_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..3fbb56d0ef9a4dbea9a39a6c55352ef14a34898d --- /dev/null +++ b/vl_ret/tokenization_clip.py @@ -0,0 +1,145 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab = self.encoder + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + def tokenize(self, text): + tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + return tokens + + def convert_tokens_to_ids(self, tokens): + return [self.encoder[bpe_token] for bpe_token in tokens] \ No newline at end of file diff --git a/vl_ret/util.py b/vl_ret/util.py new file mode 100644 index 0000000000000000000000000000000000000000..6b11cd4ea86d93304a882bebda2f5128bec7eb4d --- /dev/null +++ b/vl_ret/util.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +import threading +from torch._utils import ExceptionWrapper +import logging + +def get_a_var(obj): + if isinstance(obj, torch.Tensor): + return obj + + if isinstance(obj, list) or isinstance(obj, tuple): + for result in map(get_a_var, obj): + if isinstance(result, torch.Tensor): + return result + if isinstance(obj, dict): + for result in map(get_a_var, obj.items()): + if isinstance(result, torch.Tensor): + return result + return None + +def parallel_apply(fct, model, inputs, device_ids): + modules = nn.parallel.replicate(model, device_ids) + assert len(modules) == len(inputs) + lock = threading.Lock() + results = {} + grad_enabled = torch.is_grad_enabled() + + def _worker(i, module, input): + torch.set_grad_enabled(grad_enabled) + device = get_a_var(input).get_device() + try: + with torch.cuda.device(device): + # this also avoids accidental slicing of `input` if it is a Tensor + if not isinstance(input, (list, tuple)): + input = (input,) + output = fct(module, *input) + with lock: + results[i] = output + except Exception: + with lock: + results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device)) + + if len(modules) > 1: + threads = [threading.Thread(target=_worker, args=(i, module, input)) + for i, (module, input) in enumerate(zip(modules, inputs))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + else: + _worker(0, modules[0], inputs[0]) + + outputs = [] + for i in range(len(inputs)): + output = results[i] + if isinstance(output, ExceptionWrapper): + output.reraise() + outputs.append(output) + return outputs + +def get_logger(filename=None): + logger = logging.getLogger('logger') + logger.setLevel(logging.DEBUG) + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO) + if filename is not None: + handler = logging.FileHandler(filename) + handler.setLevel(logging.DEBUG) + handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) + logging.getLogger().addHandler(handler) + return logger \ No newline at end of file