■ Shows how to build a custom model by subclassing the LLM class.
▶ main.py
import asyncio
from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.prompts import ChatPromptTemplate


class CustomLLM(LLM):
    """A custom LLM that echoes the first `n` characters of the input.

    When contributing an implementation to LangChain, carefully document
    the model including the initialization parameters, include
    an example of how to initialize the model and include any relevant
    links to the underlying models documentation or API.

    Example:

        .. code-block:: python

            customLLM = CustomLLM(n=2)
            result = customLLM.invoke([HumanMessage(content="hello")])
            result = customLLM.batch([[HumanMessage(content="hello")],
                                      [HumanMessage(content="world")]])
    """

    n: int
    """The number of characters from the last message of the prompt to be echoed."""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Override this method to implement the LLM logic.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off
                at the first occurrence of any of the stop substrings.
                If stop tokens are not supported consider raising
                NotImplementedError.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are
                usually passed to the model provider API call.

        Returns:
            The model output as a string. Actual completions SHOULD NOT
            include the prompt.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        return prompt[:self.n]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream the LLM on the given prompt.

        This method should be overridden by subclasses that support streaming.
        If not implemented, the default behavior of calls to stream will be to
        fallback to the non-streaming version of the model and return the
        output as a single chunk.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off
                at the first occurrence of any of these substrings.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are
                usually passed to the model provider API call.

        Returns:
            An iterator of GenerationChunks.
        """
        for character in prompt[:self.n]:
            generationChunk = GenerationChunk(text=character)
            if run_manager:
                run_manager.on_llm_new_token(generationChunk.text, chunk=generationChunk)
            yield generationChunk

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": "CustomChatModel",
        }

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model.

        Used for logging purposes only."""
        return "custom"


customLLM = CustomLLM(n=5)

responseString = customLLM.invoke("This is a foobar thing")
print(responseString)
print("-" * 50)

responseStringList = customLLM.batch(["woof woof woof", "meow meow meow"])
print(responseStringList)
print("-" * 50)


async def main():
    responseString = await customLLM.ainvoke("world")
    print(responseString)
    print("-" * 50)

    responseStringList = await customLLM.abatch(["woof woof woof", "meow meow meow"])
    print(responseStringList)
    print("-" * 50)

    async for stringChunk in customLLM.astream("hello"):
        print(stringChunk, end="|", flush=True)
    print()
    print("-" * 50)

    chatPromptTemplate = ChatPromptTemplate.from_messages(
        [("system", "you are a bot"), ("human", "{input}")]
    )
    runnableSequence = chatPromptTemplate | customLLM

    index = 0
    async for eventDictionary in runnableSequence.astream_events(
        {"input": "hello there!"}, version="v2"
    ):
        print(eventDictionary)
        index += 1
        if index > 7:
            break
    print("-" * 50)


asyncio.run(main())

"""
This 
--------------------------------------------------
['woof ', 'meow ']
--------------------------------------------------
world
--------------------------------------------------
['woof ', 'meow ']
--------------------------------------------------
h|e|l|l|o|
--------------------------------------------------
{'event': 'on_chain_start', 'data': {'input': {'input': 'hello there!'}}, 'name': 'RunnableSequence', 'tags': [], 'run_id': 'b74d2aa5-6202-4792-9d24-bec5c1abd446', 'metadata': {}, 'parent_ids': []}
{'event': 'on_prompt_start', 'data': {'input': {'input': 'hello there!'}}, 'name': 'ChatPromptTemplate', 'tags': ['seq:step:1'], 'run_id': '8303741b-eb65-4ba3-b415-4235c0dff8db', 'metadata': {}, 'parent_ids': ['b74d2aa5-6202-4792-9d24-bec5c1abd446']}
{'event': 'on_prompt_end', 'data': {'output': ChatPromptValue(messages=[SystemMessage(content='you are a bot', additional_kwargs={}, response_metadata={}), HumanMessage(content='hello there!', additional_kwargs={}, response_metadata={})]), 'input': {'input': 'hello there!'}}, 'run_id': '8303741b-eb65-4ba3-b415-4235c0dff8db', 'name': 'ChatPromptTemplate', 'tags': ['seq:step:1'], 'metadata': {}, 'parent_ids': ['b74d2aa5-6202-4792-9d24-bec5c1abd446']}
{'event': 'on_llm_start', 'data': {'input': {'prompts': ['System: you are a bot\nHuman: hello there!']}}, 'name': 'CustomLLM', 'tags': ['seq:step:2'], 'run_id': '1f28dce6-8326-443f-8f09-2ca6967d6171', 'metadata': {'ls_provider': 'custom', 'ls_model_type': 'llm'}, 'parent_ids': ['b74d2aa5-6202-4792-9d24-bec5c1abd446']}
{'event': 'on_llm_stream', 'data': {'chunk': GenerationChunk(text='S')}, 'run_id': '1f28dce6-8326-443f-8f09-2ca6967d6171', 'name': 'CustomLLM', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'custom', 'ls_model_type': 'llm'}, 'parent_ids': ['b74d2aa5-6202-4792-9d24-bec5c1abd446']}
{'event': 'on_chain_stream', 'run_id': 'b74d2aa5-6202-4792-9d24-bec5c1abd446', 'name': 'RunnableSequence', 'tags': [], 'metadata': {}, 'data': {'chunk': 'S'}, 'parent_ids': []}
{'event': 'on_llm_stream', 'data': {'chunk': GenerationChunk(text='y')}, 'run_id': '1f28dce6-8326-443f-8f09-2ca6967d6171', 'name': 'CustomLLM', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'custom', 'ls_model_type': 'llm'}, 'parent_ids': ['b74d2aa5-6202-4792-9d24-bec5c1abd446']}
{'event': 'on_chain_stream', 'run_id': 'b74d2aa5-6202-4792-9d24-bec5c1abd446', 'name': 'RunnableSequence', 'tags': [], 'metadata': {}, 'data': {'chunk': 'y'}, 'parent_ids': []}
--------------------------------------------------
"""
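The listing exercises invoke, batch, ainvoke, abatch, astream, and astream_events, but not the synchronous stream method, which is the call that actually drives the _stream override. A minimal sketch of that call, assuming the CustomLLM class from main.py is already defined in the session (the streamingLLM variable name is ours for illustration):

    # stream() yields the string chunks produced by _stream, one
    # character at a time for this model.
    streamingLLM = CustomLLM(n=5)
    for stringChunk in streamingLLM.stream("hello"):
        print(stringChunk, end="|", flush=True)
    print()
    # Expected output: h|e|l|l|o|

If _stream were not overridden, stream() would fall back to _call and return the entire output as a single chunk, as the docstring in the listing notes.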
▶ requirements.txt
aiohappyeyeballs==2.4.3
aiohttp==3.11.7
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.6.2.post1
attrs==24.2.0
certifi==2024.8.30
charset-normalizer==3.4.0
frozenlist==1.5.0
greenlet==3.1.1
h11==0.14.0
httpcore==1.0.7
httpx==0.27.2
idna==3.10
jsonpatch==1.33
jsonpointer==3.0.0
langchain==0.3.7
langchain-core==0.3.19
langchain-text-splitters==0.3.2
langsmith==0.1.144
multidict==6.1.0
numpy==1.26.4
orjson==3.10.11
packaging==24.2
propcache==0.2.0
pydantic==2.10.1
pydantic_core==2.27.1
PyYAML==6.0.2
requests==2.32.3
requests-toolbelt==1.0.0
sniffio==1.3.1
SQLAlchemy==2.0.36
tenacity==9.0.0
typing_extensions==4.12.2
urllib3==2.2.3
yarl==1.18.0
※ Only the pip install langchain command was run; the remaining packages above were installed as its dependencies.
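To check that the versions pinned in requirements.txt match what actually got installed, something like the following can be run (a minimal sketch; the three package names are just the key ones from the list above):

    # Print the installed versions of the key packages for comparison
    # against the pins in requirements.txt.
    from importlib.metadata import version

    for packageName in ("langchain", "langchain-core", "langsmith"):
        print(packageName, version(packageName))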