■ OpenCLIPEmbeddings 클래스를 사용해 이미지 설명 텍스트의 특징과 이미지 특징 간 코사인 유사도를 구하는 방법을 보여준다.
▶ main.py
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
"""Compute and visualize cosine similarity between CLIP text features of
LLaVA-generated image descriptions and CLIP image features.

Pipeline: list .jpg files -> describe each with a multimodal Ollama model ->
embed images and descriptions with OpenCLIP -> plot the similarity matrix.
"""

import base64
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
from langchain.schema.messages import HumanMessage
from langchain_experimental.open_clip import OpenCLIPEmbeddings
from langchain_ollama import ChatOllama
from PIL import Image

# Directory scanned for .jpg inputs (was hard-coded twice in the original).
IMAGE_DIRECTORY = "temp"


def printMessage(message, *argumentTuple):
    """Print *message* prefixed with a [HH:MM:SS] timestamp.

    Extra positional arguments are applied via printf-style %-formatting,
    preserving the original call convention.
    """
    timeStamp = datetime.now().strftime("[%H:%M:%S]")
    finalMessage = message % argumentTuple if argumentTuple else message
    print(f"{timeStamp} {finalMessage}")


def getBASE64StringFromFile(filePath):
    """Read *filePath* as binary and return its Base64 encoding as text."""
    with open(filePath, "rb") as bufferedReader:
        imageBytes = bufferedReader.read()
    return base64.b64encode(imageBytes).decode("utf-8")


def getImageDescription(chatOllama, imageFilePath):
    """Ask the multimodal chat model to describe one image; return the text.

    The image is inlined as a base64 data URL so the model needs no file
    access of its own.
    """
    imageBASE64String = getBASE64StringFromFile(imageFilePath)
    humanMessage = HumanMessage(
        content = [
            {"type" : "text" , "text" : "Please describe the image in detail." },
            {"type" : "image_url", "image_url" : {"url" : f"data:image/jpeg;base64,{imageBASE64String}"}}
        ]
    )
    responseAIMessage = chatOllama.invoke([humanMessage])
    return responseAIMessage.content


def _getImageFilePathList(directoryPath):
    """Return the sorted list of .jpg file paths under *directoryPath*."""
    return sorted(
        os.path.join(directoryPath, imageFileName)
        for imageFileName in os.listdir(directoryPath)
        if imageFileName.endswith(".jpg")
    )


def _getImageDescriptionList(chatOllama, imageFilePathList):
    """Describe every image with the chat model; returns one text per path."""
    printMessage("START GET IMAGE DESCRIPTION")
    print("-" * 50)
    imageDescriptionList = []
    for imageFilePath in imageFilePathList:
        # Fix: the original collected results in a dict and later looped with
        # enumerate(items()) ignoring both the index and the key — a plain
        # list keeps the same (insertion/sorted) order with less machinery.
        imageDescriptionList.append(getImageDescription(chatOllama, imageFilePath))
        print(f" {imageFilePath}")
    print("-" * 50)
    printMessage("END GET IMAGE DESCRIPTION")
    print()
    return imageDescriptionList


def _getSimilarityNDArray(openCLIPEmbeddings, imageFilePathList, imageDescriptionList):
    """Embed images and description texts, return the similarity matrix.

    OpenCLIPEmbeddings returns unit-normalized feature vectors, so the plain
    dot product equals cosine similarity — TODO confirm against the installed
    langchain_experimental version.
    Row i = description i, column j = image j.
    """
    printMessage("START GET IMAGE FEATURE LIST LIST")
    imageFeatureListList = openCLIPEmbeddings.embed_image(imageFilePathList)
    printMessage("END GET IMAGE FEATURE LIST LIST")
    print()

    printMessage("START GET DESCRIPTION FEATURE LIST")
    imageDescriptionFeatureListList = openCLIPEmbeddings.embed_documents(
        ["This is " + imageDescription for imageDescription in imageDescriptionList]
    )
    printMessage("END GET DESCRIPTION FEATURE LIST")
    print()

    printMessage("START CALCULATE SIMILARITY ND ARRAY")
    imageFeatureNDArray = np.array(imageFeatureListList)
    imageDescriptionFeatureNDArray = np.array(imageDescriptionFeatureListList)
    similarityNDArray = np.matmul(imageDescriptionFeatureNDArray, imageFeatureNDArray.T)
    printMessage("END CALCULATE SIMILARITY ND ARRAY")
    print()
    return similarityNDArray


def _plotSimilarity(similarityNDArray, imageFilePathList, imageDescriptionList):
    """Render the similarity heatmap with image thumbnails along the x axis."""
    printMessage("START GET SOURCE IMAGE LIST")
    # Comprehension replaces the original append loop with its unused
    # enumerate index.
    sourceImageList = [Image.open(imageFilePath).convert("RGB") for imageFilePath in imageFilePathList]
    printMessage("END GET SOURCE IMAGE LIST")

    count = len(imageDescriptionList)
    plt.figure(figsize = (20, 14))
    # vmin/vmax clamp the colormap to the range CLIP similarities typically
    # occupy, so small differences remain visible.
    plt.imshow(similarityNDArray, vmin = 0.1, vmax = 0.3, cmap = "coolwarm")
    plt.colorbar()

    # Fix: truncate into a new list for tick labels instead of mutating the
    # caller's description list in place during plotting.
    truncatedDescriptionList = [imageDescription[:50] for imageDescription in imageDescriptionList]
    plt.yticks(range(count), truncatedDescriptionList, fontsize = 10)
    plt.xticks([])

    # Draw each source image as a thumbnail just above the matrix, one per
    # column, using data-coordinate extents.
    for i, image in enumerate(sourceImageList):
        plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin = "lower")

    # Annotate every cell with its similarity value.
    for x in range(similarityNDArray.shape[1]):
        for y in range(similarityNDArray.shape[0]):
            plt.text(x, y, f"{similarityNDArray[y, x]:.2f}", ha = "center", va = "center", size = 12)

    for side in ["left", "top", "right", "bottom"]:
        plt.gca().spines[side].set_visible(False)
    plt.xlim([-0.5, count - 0.5])
    # Extend the y range past the matrix so the thumbnail row stays visible.
    plt.ylim([count + 0.5, -2])
    plt.title("Cosine similarity between text and image features", size = 12)
    plt.tight_layout()
    plt.show()


def main():
    """Run the full describe -> embed -> compare -> plot pipeline."""
    printMessage("START GET IMAGE FILE PATH LIST")
    imageFilePathList = _getImageFilePathList(IMAGE_DIRECTORY)
    printMessage("END GET IMAGE FILE PATH LIST")
    print()

    printMessage("START CREATE CHATOLLAMA")
    # temperature=0 keeps the generated descriptions deterministic.
    chatOllama = ChatOllama(model = "llava:latest", temperature = 0)
    printMessage("END CREATE CHATOLLAMA")
    print()

    imageDescriptionList = _getImageDescriptionList(chatOllama, imageFilePathList)

    printMessage("START CREATE OPENCLIPEMBEDDINGS")
    openCLIPEmbeddings = OpenCLIPEmbeddings(model_name = "ViT-H-14-378-quickgelu", checkpoint = "dfn5b")
    printMessage("END CREATE OPENCLIPEMBEDDINGS")
    print()

    similarityNDArray = _getSimilarityNDArray(openCLIPEmbeddings, imageFilePathList, imageDescriptionList)
    _plotSimilarity(similarityNDArray, imageFilePathList, imageDescriptionList)


# Fix: guard top-level side effects so importing this module no longer runs
# the whole pipeline.
if __name__ == "__main__":
    main()
▶ requirements.txt
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
aiohappyeyeballs==2.4.4 aiohttp==3.11.11 aiosignal==1.3.2 annotated-types==0.7.0 anyio==4.8.0 attrs==24.3.0 certifi==2024.12.14 charset-normalizer==3.4.1 colorama==0.4.6 contourpy==1.3.1 cycler==0.12.1 dataclasses-json==0.6.7 filelock==3.16.1 fonttools==4.55.3 frozenlist==1.5.0 fsspec==2024.12.0 ftfy==6.3.1 greenlet==3.1.1 h11==0.14.0 httpcore==1.0.7 httpx==0.27.2 httpx-sse==0.4.0 huggingface-hub==0.27.1 idna==3.10 Jinja2==3.1.5 jsonpatch==1.33 jsonpointer==3.0.0 kiwisolver==1.4.8 langchain==0.3.14 langchain-community==0.3.14 langchain-core==0.3.29 langchain-experimental==0.3.4 langchain-ollama==0.2.2 langchain-text-splitters==0.3.5 langsmith==0.2.10 MarkupSafe==3.0.2 marshmallow==3.25.1 matplotlib==3.10.0 mpmath==1.3.0 multidict==6.1.0 mypy-extensions==1.0.0 networkx==3.4.2 numpy==2.2.1 ollama==0.4.6 open_clip_torch==2.30.0 orjson==3.10.14 packaging==24.2 pillow==11.1.0 propcache==0.2.1 pydantic==2.10.5 pydantic-settings==2.7.1 pydantic_core==2.27.2 pyparsing==3.2.1 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 PyYAML==6.0.2 regex==2024.11.6 requests==2.32.3 requests-toolbelt==1.0.0 safetensors==0.5.2 setuptools==75.8.0 six==1.17.0 sniffio==1.3.1 SQLAlchemy==2.0.37 sympy==1.13.1 tenacity==9.0.0 timm==1.0.13 torch==2.5.1 torchvision==0.20.1 tqdm==4.67.1 typing-inspect==0.9.0 typing_extensions==4.12.2 urllib3==2.3.0 wcwidth==0.2.13 yarl==1.18.3 |
※ pip install langchain_ollama langchain_experimental open_clip_torch matplotlib 명령을 실행했다.