■ ML.NET 다중 클래스 분류 모델을 사용해 이슈의 제목과 설명으로 영역(Area)을 예측하는 방법을 보여준다.
▶ IssueData.cs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
using Microsoft.ML.Data;

namespace TestProject
{
    /// <summary>
    /// One row of issue data loaded from a text file.
    /// Each property is bound to a source column by the zero-based index given in
    /// its <see cref="LoadColumnAttribute"/>.
    /// </summary>
    public class IssueData
    {
        /// <summary>Issue identifier (source column 0).</summary>
        [LoadColumn(0)]
        public string ID { get; set; }

        /// <summary>Area the issue belongs to (source column 1).</summary>
        [LoadColumn(1)]
        public string Area { get; set; }

        /// <summary>Issue title (source column 2).</summary>
        [LoadColumn(2)]
        public string Title { get; set; }

        /// <summary>Issue description (source column 3).</summary>
        [LoadColumn(3)]
        public string Description { get; set; }
    }
}
▶ IssuePrediction.cs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
using Microsoft.ML.Data;

namespace TestProject
{
    /// <summary>
    /// Prediction result for a single issue.
    /// </summary>
    public class IssuePrediction
    {
        /// <summary>
        /// Predicted area, read from the model output column named "PredictedLabel".
        /// </summary>
        /// <remarks>
        /// Declared as an auto-property rather than a public field, for consistency with
        /// <c>IssueData</c> and the .NET design guideline against visible instance fields (CA1051).
        /// A setter is required so the prediction engine can populate the value.
        /// </remarks>
        [ColumnName("PredictedLabel")]
        public string Area { get; set; }
    }
}
▶ Program.cs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 |
using System;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace TestProject
{
    /// <summary>
    /// Console program that trains a multiclass classifier predicting an issue's Area from its
    /// Title and Description. It then evaluates the model on test data, saves it to disk,
    /// reloads it, and runs a prediction with the reloaded model.
    /// </summary>
    class Program
    {
        #region Field

        /// <summary>
        /// Directory of the first command-line argument (the application path).
        /// </summary>
        private static string _applicationPath => Path.GetDirectoryName(Environment.GetCommandLineArgs()[0]);

        /// <summary>
        /// Training data file path (Data/issues_train.tsv under the application directory).
        /// </summary>
        private static string _trainingDataFilePath => Path.Combine(_applicationPath, "Data", "issues_train.tsv");

        /// <summary>
        /// Test data file path (Data/issues_test.tsv under the application directory).
        /// </summary>
        private static string _testDataFilePath => Path.Combine(_applicationPath, "Data", "issues_test.tsv");

        /// <summary>
        /// Path the trained model is saved to and reloaded from.
        /// </summary>
        private static string _modelFilePath => Path.Combine(_applicationPath, "Data", "model.zip");

        /// <summary>
        /// ML context shared by every step; created with a fixed seed in <see cref="Main"/>.
        /// </summary>
        private static MLContext _context;

        /// <summary>
        /// Model trained in <see cref="Main"/>.
        /// </summary>
        private static ITransformer _model;

        /// <summary>
        /// Training data view.
        /// </summary>
        private static IDataView _trainingDataView;

        #endregion

        #region Program entry point - Main()

        /// <summary>
        /// Runs the whole workflow: load data, build the pipeline, train, predict one item,
        /// evaluate, save the model, then predict with the reloaded model.
        /// </summary>
        private static void Main()
        {
            Console.WriteLine("BEGIN MAIN FUNCTION");

            // A fixed seed makes training results repeatable between runs.
            _context = new MLContext(seed: 0);

            Console.WriteLine("BEGIN SET TRAINING DATA VIEW");
            _trainingDataView = _context.Data.LoadFromTextFile<IssueData>(_trainingDataFilePath, hasHeader: true);
            Console.WriteLine("END SET TRAINING DATA VIEW");

            Console.WriteLine("BEGIN SET PIPE LINE");
            var pipeline = GetPipeLine();
            Console.WriteLine("END SET PIPE LINE");

            _model = GetModel(_trainingDataView, pipeline);

            Console.WriteLine("BEGIN PREDICT SINGLE ITEM");
            IssueData issueData = new IssueData()
            {
                Title       = "WebSockets communication is slow in my machine",
                Description = "The WebSockets communication used under the covers by SignalR looks like is going slow in my development machine.."
            };
            PredictAndPrintArea(_model, issueData);
            Console.WriteLine("END PREDICT SINGLE ITEM");

            Evaluate(_trainingDataView.Schema);

            SaveModel(_context, _trainingDataView.Schema, _model);

            PredictIssue();

            Console.WriteLine("END MAIN FUNCTION");
        }

        #endregion

        #region Build pipeline - GetPipeLine()

        /// <summary>
        /// Builds the training pipeline:
        /// 1. map "Area" to the key column "Label";
        /// 2. featurize "Title" and "Description" and concatenate them into "Features";
        /// 3. cache, then train SdcaMaximumEntropy;
        /// 4. map "PredictedLabel" back to its original value.
        /// </summary>
        /// <returns>Untrained pipeline estimator.</returns>
        private static IEstimator<ITransformer> GetPipeLine()
        {
            Console.WriteLine("BEGIN GET PIPE LINE FUNCTION");

            var pipeline = _context.Transforms.Conversion.MapValueToKey(inputColumnName: "Area", outputColumnName: "Label")
                .Append(_context.Transforms.Text.FeaturizeText(inputColumnName: "Title", outputColumnName: "TitleFeaturized"))
                .Append(_context.Transforms.Text.FeaturizeText(inputColumnName: "Description", outputColumnName: "DescriptionFeaturized"))
                .Append(_context.Transforms.Concatenate("Features", "TitleFeaturized", "DescriptionFeaturized"))
                .AppendCacheCheckpoint(_context)
                .Append(_context.MulticlassClassification.Trainers.SdcaMaximumEntropy("Label", "Features"))
                .Append(_context.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

            Console.WriteLine("END GET PIPE LINE FUNCTION");

            return pipeline;
        }

        #endregion

        #region Train model - GetModel(trainingDataView, pipeline)

        /// <summary>
        /// Fits the pipeline to the training data.
        /// </summary>
        /// <param name="trainingDataView">Training data view.</param>
        /// <param name="pipeline">Pipeline to fit.</param>
        /// <returns>Trained model.</returns>
        private static ITransformer GetModel(IDataView trainingDataView, IEstimator<ITransformer> pipeline)
        {
            Console.WriteLine("BEGIN GET MODEL FUNCTION");

            var model = pipeline.Fit(trainingDataView);

            Console.WriteLine("END GET MODEL FUNCTION");

            return model;
        }

        #endregion

        #region Evaluate - Evaluate(trainingDataViewSchema)

        /// <summary>
        /// Loads the test data, scores it with the trained model, and prints multiclass metrics.
        /// </summary>
        /// <param name="trainingDataViewSchema">Training data view schema (not used by this method).</param>
        private static void Evaluate(DataViewSchema trainingDataViewSchema)
        {
            Console.WriteLine("BEGIN EVALUATE FUNCTION");

            IDataView testDataView = _context.Data.LoadFromTextFile<IssueData>(_testDataFilePath, hasHeader: true);

            MulticlassClassificationMetrics metrics = _context.MulticlassClassification.Evaluate(_model.Transform(testDataView));

            Console.WriteLine("--------------------------------------------------");
            Console.WriteLine("METRICS FOR MULTI-CLASS CLASSIFICATION MODEL - TEST DATA");
            Console.WriteLine("--------------------------------------------------");
            Console.WriteLine($"MICRO ACCURACY : {metrics.MicroAccuracy:0.###}");
            Console.WriteLine($"MACRO ACCURACY : {metrics.MacroAccuracy:0.###}");
            // "0.###" (not "#.###") keeps the leading zero for values below 1, e.g. "0.42" rather than ".42".
            Console.WriteLine($"LOG LOSS : {metrics.LogLoss:0.###}");
            Console.WriteLine($"LOG LOSS REDUCTION : {metrics.LogLossReduction:0.###}");
            Console.WriteLine("--------------------------------------------------");

            Console.WriteLine("END EVALUATE FUNCTION");
        }

        #endregion

        #region Save model - SaveModel(context, trainingDataViewSchema, model)

        /// <summary>
        /// Saves the model, together with its input schema, to the model file path.
        /// </summary>
        /// <param name="context">ML context.</param>
        /// <param name="trainingDataViewSchema">Input schema stored alongside the model.</param>
        /// <param name="model">Model to save.</param>
        private static void SaveModel(MLContext context, DataViewSchema trainingDataViewSchema, ITransformer model)
        {
            Console.WriteLine("BEGIN SAVE MODEL FUNCTION");

            context.Model.Save(model, trainingDataViewSchema, _modelFilePath);

            Console.WriteLine("END SAVE MODEL FUNCTION");
        }

        #endregion

        #region Predict with saved model - PredictIssue()

        /// <summary>
        /// Reloads the saved model from disk and predicts the area of a sample issue.
        /// </summary>
        private static void PredictIssue()
        {
            Console.WriteLine("BEGIN PREDICT ISSUE FUNCTION");

            // The stored input schema is not needed here, so it is discarded.
            ITransformer model = _context.Model.Load(_modelFilePath, out _);

            IssueData issueData = new IssueData()
            {
                Title       = "Entity Framework crashes",
                Description = "When connecting to the database, EF is crashing"
            };

            PredictAndPrintArea(model, issueData);

            Console.WriteLine("END PREDICT ISSUE FUNCTION");
        }

        #endregion

        #region Predict and print area - PredictAndPrintArea(model, issueData)

        /// <summary>
        /// Predicts the area of one issue with the given model and prints it between separator lines.
        /// </summary>
        /// <param name="model">Trained or reloaded model.</param>
        /// <param name="issueData">Issue to classify.</param>
        private static void PredictAndPrintArea(ITransformer model, IssueData issueData)
        {
            // PredictionEngine implements IDisposable. Create it per call and dispose it,
            // instead of overwriting a shared field and leaking the previous engine.
            using (PredictionEngine<IssueData, IssuePrediction> predictionEngine =
                _context.Model.CreatePredictionEngine<IssueData, IssuePrediction>(model))
            {
                IssuePrediction issuePrediction = predictionEngine.Predict(issueData);

                Console.WriteLine("--------------------------------------------------");
                Console.WriteLine($"AREA : {issuePrediction.Area}");
                Console.WriteLine("--------------------------------------------------");
            }
        }

        #endregion
    }
}