■ ChatOllama 클래스(llama3.2:3b 모델)와 토큰 수 계산을 이용한 재귀적 요약으로 PDF 문서를 요약하는 방법을 보여준다.
▶ main.py
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
from dotenv import load_dotenv
from datetime import datetime
from langchain_ollama import ChatOllama
from transformers import AutoTokenizer
from langchain.prompts import PromptTemplate
from langchain.chains.summarize import load_summarize_chain
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

load_dotenv()


def printMessage(message, *argumentTuple):
    """Print *message* prefixed with the current local time as "[HH:MM:SS]".

    When extra positional arguments are supplied, *message* is treated as a
    %-style format string and formatted with them.
    """
    timeStamp = datetime.now().strftime("[%H:%M:%S]")
    finalMessage = message % argumentTuple if argumentTuple else message
    print(f"{timeStamp} {finalMessage}")


class PDFSummarizer:
    """Summarize a PDF file with a local Ollama chat model.

    The PDF is split into chunks and summarized with a map-reduce chain. The
    result is then re-summarized chunk by chunk until its token count fits
    within ``maximumTokenCount``.
    """

    def __init__(self, modelName = "llama3.2:3b", maximumTokenCount = 131072, chunkSize = 1000, chunkOverlap = 200):
        """Create the model, tokenizer, prompts and chains.

        Args:
            modelName: Ollama model name.
            maximumTokenCount: token budget; text at or below this count is
                returned unchanged by ``summarizeRecursively``.
            chunkSize: character size of each split chunk.
            chunkOverlap: character overlap between consecutive chunks.
        """
        # NOTE(review): langchain_ollama's ChatOllama exposes `num_predict` /
        # `num_ctx` for length limits; confirm `max_tokens` is actually honored.
        self.llm = ChatOllama(model = modelName, temperature = 0, max_tokens = maximumTokenCount)
        self.maximumTokenCount = maximumTokenCount
        self.chunkSize = chunkSize
        self.chunkOverlap = chunkOverlap
        # NOTE(review): this is a generic Llama test tokenizer, not llama3.2's
        # own tokenizer, so the counts are presumably only approximate.
        self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", legacy = False)

        # Prompt applied to each chunk in the map step.
        self.summarizeDocumentPromptTemplateString = """아래 텍스트의 핵심 내용을 요약해주세요 : {text} 핵심 요약 :"""
        self.summarizeDocumentPromptTemplate = PromptTemplate(template = self.summarizeDocumentPromptTemplateString, input_variables = ["text"])

        # Prompt that merges partial summaries (reduce step and recursive passes).
        self.summaryPromptTemplateString = (
            "아래는 긴 문서의 각 부분을 요약한 내용입니다. \n"
            "이 요약들을 하나의 일관된 요약으로 만들어주세요: {text} 최종 요약 :"
        )
        self.summaryPromptTemplate = PromptTemplate(template = self.summaryPromptTemplateString, input_variables = ["text"])

        self.mapReduceDocumentsChain = load_summarize_chain(
            self.llm,
            chain_type = "map_reduce",
            map_prompt = self.summarizeDocumentPromptTemplate,
            combine_prompt = self.summaryPromptTemplate,
            verbose = False
        )
        # Single-call chain used to re-summarize one chunk during recursion.
        # It is built once here instead of once per chunk.
        self.stuffDocumentsChain = load_summarize_chain(self.llm, chain_type = "stuff", prompt = self.summaryPromptTemplate)

    @staticmethod
    def _extractText(source):
        """Return the "output_text" of a chain-result dict, otherwise ``str(source)``."""
        if isinstance(source, dict) and "output_text" in source:
            return source["output_text"]
        return str(source)

    def _createTextSplitter(self):
        """Return a character splitter configured with this instance's chunk settings."""
        return RecursiveCharacterTextSplitter(chunk_size = self.chunkSize, chunk_overlap = self.chunkOverlap)

    def getTokenCount(self, source):
        """Return the number of tokens in *source*.

        For a chain-result dictionary only its "output_text" is counted, not
        the repr of the whole dict. Any other value is counted via ``str()``.
        """
        printMessage("START GET TOKEN COUNT FUNCTION")
        text = self._extractText(source)
        tokenList = self.tokenizer.encode(text, add_special_tokens = False)
        tokenCount = len(tokenList)
        printMessage(f" TOKEN COUNT : {tokenCount}")
        printMessage("END GET TOKEN COUNT FUNCTION")
        return tokenCount

    def loadDocument(self, filePath):
        """Load the PDF at *filePath* with PyPDFLoader and return it split into chunks."""
        printMessage("START LOAD DOCUMENT FUNCTION")
        documentList = PyPDFLoader(filePath).load()
        splitDocumentList = self._createTextSplitter().split_documents(documentList)
        printMessage(f" DOCUMENT LIST LENGTH : {len(documentList)}")
        printMessage(f" SPLIT DOCUMENT LIST LENGTH : {len(splitDocumentList)}")
        printMessage("END LOAD DOCUMENT FUNCTION")
        return splitDocumentList

    def summarizeDocumentList(self, documentList):
        """Run the map-reduce chain over *documentList* and return its result dict."""
        printMessage("START SUMMARIZE DOCUMENT LIST FUNCTION")
        responseDictionary = self.mapReduceDocumentsChain.invoke(documentList)
        printMessage("END SUMMARIZE DOCUMENT LIST FUNCTION")
        return responseDictionary

    def summarizeRecursively(self, text, level = 0):
        """Re-summarize *text* until it fits within ``maximumTokenCount`` tokens.

        Args:
            text: a string, or a chain-result dictionary whose "output_text"
                is used.
            level: recursion depth, used only to indent log lines.

        Returns:
            The summary string. Recursion also stops early when a pass fails
            to reduce the token count, because further passes would not
            converge.
        """
        indentString = ' ' * level
        printMessage(f"{indentString}START SUMMARIZE RECURSIVELY FUNCTION")
        # Normalize first: a result dict cannot be passed to create_documents().
        text = self._extractText(text)
        tokenCount = self.getTokenCount(text)
        if tokenCount <= self.maximumTokenCount:
            printMessage(f"{indentString}RETURN TEXT IF TOKEN COUNT <= SELF.MAXIMUM TOKEN COUNT")
            printMessage(f"{indentString}END SUMMARIZE RECURSIVELY FUNCTION")
            return text

        splitDocumentList = self._createTextSplitter().create_documents([text])
        splitDocumentListLength = len(splitDocumentList)
        summaryList = []
        for index, splitDocument in enumerate(splitDocumentList):
            # invoke() (as in summarizeDocumentList) instead of the deprecated run().
            summary = self.stuffDocumentsChain.invoke([splitDocument])["output_text"]
            printMessage(f"{indentString} {index + 1}/{splitDocumentListLength} SUMMARY : {len(summary)}")
            summaryList.append(summary)

        totalSummary = "\n\n".join(summaryList)

        # Guard against unbounded recursion when the summaries stop shrinking.
        if self.getTokenCount(totalSummary) >= tokenCount:
            printMessage(f"{indentString}STOP RECURSION : SUMMARY DID NOT SHRINK")
            printMessage(f"{indentString}END SUMMARIZE RECURSIVELY FUNCTION")
            return totalSummary

        finalSummary = self.summarizeRecursively(totalSummary, level + 1)
        printMessage(f"{indentString}END SUMMARIZE RECURSIVELY FUNCTION")
        return finalSummary

    def summarize(self, filePath):
        """Load, map-reduce summarize, then recursively condense the PDF at *filePath*."""
        printMessage("START SUMMARIZE FUNCTION")
        documentList = self.loadDocument(filePath)
        initialSummaryDictionary = self.summarizeDocumentList(documentList)
        finalSummary = self.summarizeRecursively(initialSummaryDictionary)
        printMessage("END SUMMARIZE FUNCTION")
        return finalSummary


if __name__ == "__main__":
    pdfSummarizer = PDFSummarizer()
    summary = pdfSummarizer.summarize("SPRI_AI_Brief_2023년12월호_F.pdf")
    printMessage("최종 요약 :")
    print("-" * 50)
    print(summary)

# Sample output from a run of the earlier version of this script. That version
# logged "START" twice in loadDocument and counted str() of the whole chain
# result (input documents included) in getTokenCount. Its LOAD DOCUMENT log
# lines and TOKEN COUNT therefore differ from what this version prints.
"""
[00:18:20] START SUMMARIZE FUNCTION
[00:18:20] START LOAD DOCUMENT FUNCTION
[00:18:21] DOCUMENT LIST LENGTH : 23
[00:18:21] SPLIT DOCUMENT LIST LENGTH : 60
[00:18:21] START LOAD DOCUMENT FUNCTION
[00:18:21] START SUMMARIZE DOCUMENT LIST FUNCTION
Token indices sequence length is longer than the specified maximum sequence length for this model (25209 > 1024). Running this sequence through the model will result in indexing errors
[00:20:33] END SUMMARIZE DOCUMENT LIST FUNCTION
[00:20:33] START SUMMARIZE RECURSIVELY FUNCTION
[00:20:33] START GET TOKEN COUNT FUNCTION
Token indices sequence length is longer than the specified maximum sequence length for this model (42435 > 2048). Running this sequence through the model will result in indexing errors
[00:20:33] TOKEN COUNT : 42435
[00:20:33] END GET TOKEN COUNT FUNCTION
[00:20:33] RETURN TEXT IF TOKEN COUNT <= SELF.MAXIMUM TOKEN COUNT
[00:20:33] END SUMMARIZE FUNCTION
[00:20:33] 최종 요약 :
--------------------------------------------------
AI 산업의 현재 동향을 요약하면 다음과 같습니다.

* 미국과 영국 등 28개국이 AI 안전성에 대한 협상을 시작했습니다. 이 협력은 AI 안전 보장을 위한 강화와 사회적 위험을 완화하는 데 중요한 역할을 할 것입니다.
* 미국 프런티어 모델 포럼은 1,000만 달러 규모의 AI 안전 기금을 조성했습니다. 구글은 앤드로픽에 20억 달러 투자로 generation AI 협력 강화를 시작했습니다.
* 코히어는 데이터 투명성을 확보하기 위해 데이터 출처 탐색기를 공개했습니다.
* 알리바바 클라우드는 최신 LLM '통이치엔원 2.0'를 공개했습니다. 삼성전자는 자체 개발 생성 AI '삼성 가우스'를 공개했습니다.
* 미국 정부는 AI 사용을 촉진하고 맞춤형 개인교습 등 학교 내 AI 교육 도구를 개발하여 AI로 인한 근로자 피해를 완화하고 이점을 극대화하는 것을 목표로 한다. G7은 AI 국제 행동강령을 발표했습니다.
* 영국 정부는 새로운 AI Safety Institute를 설립하여 AI 안전 연구를 진행하고, 10년간 공공자금을 투자해 연구를 지원할 계획이다.
"""
▶ requirements.txt
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
aiohappyeyeballs==2.4.4 aiohttp==3.11.11 aiosignal==1.3.2 annotated-types==0.7.0 anyio==4.8.0 async-timeout==4.0.3 attrs==24.3.0 certifi==2024.12.14 charset-normalizer==3.4.1 dataclasses-json==0.6.7 exceptiongroup==1.2.2 filelock==3.16.1 frozenlist==1.5.0 fsspec==2024.12.0 greenlet==3.1.1 h11==0.14.0 httpcore==1.0.7 httpx==0.27.2 httpx-sse==0.4.0 huggingface-hub==0.27.1 idna==3.10 Jinja2==3.1.5 jsonpatch==1.33 jsonpointer==3.0.0 langchain==0.3.14 langchain-community==0.3.14 langchain-core==0.3.30 langchain-ollama==0.2.2 langchain-text-splitters==0.3.5 langsmith==0.2.11 MarkupSafe==3.0.2 marshmallow==3.25.1 mpmath==1.3.0 multidict==6.1.0 mypy-extensions==1.0.0 networkx==3.4.2 numpy==1.26.4 nvidia-cublas-cu12==12.4.5.8 nvidia-cuda-cupti-cu12==12.4.127 nvidia-cuda-nvrtc-cu12==12.4.127 nvidia-cuda-runtime-cu12==12.4.127 nvidia-cudnn-cu12==9.1.0.70 nvidia-cufft-cu12==11.2.1.3 nvidia-curand-cu12==10.3.5.147 nvidia-cusolver-cu12==11.6.1.9 nvidia-cusparse-cu12==12.3.1.170 nvidia-nccl-cu12==2.21.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127 ollama==0.4.6 orjson==3.10.14 packaging==24.2 propcache==0.2.1 pydantic==2.10.5 pydantic-settings==2.7.1 pydantic_core==2.27.2 pypdf==5.1.0 python-dotenv==1.0.1 PyYAML==6.0.2 regex==2024.11.6 requests==2.32.3 requests-toolbelt==1.0.0 safetensors==0.5.2 sniffio==1.3.1 SQLAlchemy==2.0.37 sympy==1.13.1 tenacity==9.0.0 tokenizers==0.21.0 torch==2.5.1 tqdm==4.67.1 transformers==4.48.0 triton==3.1.0 typing-inspect==0.9.0 typing_extensions==4.12.2 urllib3==2.3.0 yarl==1.18.3 |
※ pip install python-dotenv langchain-community langchain-ollama transformers pypdf torch 명령을 실행했다.