Core ML Framework Detailed Analysis (21) — Creating a Custom Image Filter with Style Transfer on an iOS Device (2)

Version Record

Version    Date
V1.0       2022.09.11 (Sunday)

Foreword

Leaders across the tech industry widely regard artificial intelligence as the next technological revolution, and Apple, as one of the industry's giants, is naturally keeping pace: the iOS API now includes a new framework, Core ML. ML is short for Machine Learning, a technology that is extremely popular right now and sits at the very core of artificial intelligence. If you're interested, take a look at my earlier articles below.
1. Core ML Framework Detailed Analysis (1) — A Basic Overview of Core ML
2. Core ML Framework Detailed Analysis (2) — Getting a Model and Integrating It into an App
3. Core ML Framework Detailed Analysis (3) — Classifying Images with Vision and Core ML
4. Core ML Framework Detailed Analysis (4) — Converting a Trained Model to Core ML
5. Core ML Framework Detailed Analysis (5) — A Simple Core ML Example (1)
6. Core ML Framework Detailed Analysis (6) — A Simple Core ML Example (2)
7. Core ML Framework Detailed Analysis (7) — Reducing the Size of a Core ML App (1)
8. Core ML Framework Detailed Analysis (8) — Downloading and Compiling a Model on the User's Device (1)
9. Core ML Framework Detailed Analysis (9) — Making Predictions with a Sequence of Inputs (1)
10. Core ML Framework Detailed Analysis (10) — Integrating a Custom Layer (1)
11. Core ML Framework Detailed Analysis (11) — Creating a Custom Layer (1)
12. Core ML Framework Detailed Analysis (12) — Getting Started with Machine Learning Using scikit-learn (1)
13. Core ML Framework Detailed Analysis (13) — Getting Started with Machine Learning Using Keras and Core ML (1)
14. Core ML Framework Detailed Analysis (14) — Getting Started with Machine Learning Using Keras and Core ML (2)
15. Core ML Framework Detailed Analysis (15) — Machine Learning: Classification (1)
16. Core ML Framework Detailed Analysis (16) — Artificial Intelligence and IBM Watson Services (1)
17. Core ML Framework Detailed Analysis (17) — A Simple Core ML and Vision Example (1)
18. Core ML Framework Detailed Analysis (18) — On-Device Training with Core ML and Vision (1)
19. Core ML Framework Detailed Analysis (19) — On-Device Training with Core ML and Vision (2)
20. Core ML Framework Detailed Analysis (20) — Creating a Custom Image Filter with Style Transfer on an iOS Device (1)

Source Code

1. Swift

First, take a look at how the project is organized.
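(The original screenshot is not reproduced here. Reconstructed from the listings below, the project consists of eight Swift files, plus a TrainingData resource folder of content images and a PresetStyle_1 image asset, both referenced from Constants.swift:)

AppMain.swift
ContentView.swift
ImagePicker.swift
MLStyleTransferHelper.swift
MLModelTrainer.swift
MLPredictor.swift
Constants.swift
UIImage+Utilities.swift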

Below is the source code itself.

1. AppMain.swift
import SwiftUI

@main
struct AppMain: App {
  var body: some Scene {
    WindowGroup {
      ContentView()
    }
  }
}
2. ContentView.swift
import SwiftUI

struct AlertMessage: Identifiable {
  let id = UUID()
  var title: Text
  var message: Text
  var actionButton: Alert.Button?
  var cancelButton: Alert.Button = .default(Text("OK"))
}

struct PickerInfo: Identifiable {
  let id = UUID()
  let picker: PickerView
}

struct ContentView: View {
  @State private var image: UIImage?
  @State private var styleImage: UIImage?
  @State private var stylizedImage: UIImage?
  @State private var processing = false
  @State private var showAlertMessage: AlertMessage?
  @State private var showImagePicker: PickerInfo?

  var body: some View {
    VStack {
      Text("PETRA")
        .font(.title)
      Spacer()
      Button(action: {
        if self.stylizedImage != nil {
          self.showAlertMessage = .init(
            title: Text("Choose new image?"),
            message: Text("This will clear the existing image!"),
            actionButton: .destructive(
              Text("Continue")) {
                self.stylizedImage = nil
                self.image = nil
                self.showImagePicker = PickerInfo(picker: PickerView(selectedImage: self.$image))
            },
            cancelButton: .cancel(Text("Cancel")))
        } else {
          self.showImagePicker = PickerInfo(picker: PickerView(selectedImage: self.$image))
        }
      }, label: {
        if let anImage = self.stylizedImage ?? self.image {
          Image(uiImage: anImage)
            .resizable()
            .scaledToFit()
            .border(.blue, width: 3)
        } else {
          Text("Choose a Pet Image")
            .font(.callout)
            .foregroundColor(.blue)
            .padding()
            .cornerRadius(10)
            .border(.blue, width: 3)
        }
      })
      Spacer()
      Text("Choose Style to Apply")
      Button(action: {
        self.showImagePicker = PickerInfo(picker: PickerView(selectedImage: self.$styleImage))
      }, label: {
        Image(uiImage: styleImage ?? UIImage(named: Constants.Path.presetStyle1) ?? UIImage())
          .resizable()
          .frame(width: 100, height: 100, alignment: .center)
          .scaledToFit()
          .cornerRadius(10)
          .border(.blue, width: 3)
      })
      Button(action: {
        guard let petImage = image, let styleImage = styleImage ?? UIImage(named: Constants.Path.presetStyle1) else {
          self.showAlertMessage = .init(
            title: Text("Error"),
            message: Text("You need to choose a Pet photo before applying the style!"),
            actionButton: nil,
            cancelButton: .default(Text("OK")))
          return
        }
        if !self.processing {
          self.processing = true
          MLStyleTransferHelper.shared.applyStyle(styleImage, on: petImage) { stylizedImage in
            // Training and prediction may complete off the main queue;
            // hop back before touching @State properties.
            DispatchQueue.main.async {
              self.processing = false
              self.stylizedImage = stylizedImage
            }
          }
        }
      }, label: {
        Text(self.processing ? "Processing..." : "Apply Style!")
          .padding(EdgeInsets(top: 4, leading: 8, bottom: 4, trailing: 8))
          .font(.callout)
          .background(.blue)
          .foregroundColor(.white)
          .cornerRadius(8)
      })
      .padding()
    }
    .sheet(item: self.$showImagePicker) { pickerInfo in
      return pickerInfo.picker
    }
    .alert(item: self.$showAlertMessage) { alertMessage in
      if let actionButton = alertMessage.actionButton {
        return Alert(
          title: alertMessage.title,
          message: alertMessage.message,
          primaryButton: actionButton,
          secondaryButton: alertMessage.cancelButton)
      } else {
        return Alert(
          title: alertMessage.title,
          message: alertMessage.message,
          dismissButton: alertMessage.cancelButton)
      }
    }
  }
}

struct ContentView_Previews: PreviewProvider {
  static var previews: some View {
    ContentView()
  }
}
3. ImagePicker.swift
import Foundation
import SwiftUI
import UIKit

struct PickerView: UIViewControllerRepresentable {
  @Binding var selectedImage: UIImage?
  @Environment(\.presentationMode) private var presentationMode
  func makeUIViewController(context: Context) -> UIImagePickerController {
    let imagePicker = UIImagePickerController()
    imagePicker.sourceType = .photoLibrary
    imagePicker.delegate = context.coordinator
    return imagePicker
  }
  func makeCoordinator() -> Coordinator {
    Coordinator { image in
      self.selectedImage = image
      self.presentationMode.wrappedValue.dismiss()
    }
  }
  func updateUIViewController(_ uiViewController: UIImagePickerController, context: Context) {
  }
  // MARK: - Coordinator
  final class Coordinator: NSObject, UIImagePickerControllerDelegate, UINavigationControllerDelegate {
    private let onComplete: (UIImage?) -> Void
    init(withCompletion onComplete: @escaping (UIImage?) -> Void) {
      self.onComplete = onComplete
    }
    func imagePickerController(_ picker: UIImagePickerController, didFinishPickingMediaWithInfo info: [UIImagePickerController.InfoKey: Any]) {
      if let image = info[.originalImage] as? UIImage {
        self.onComplete(image.upOrientationImage())
      }
    }
    func imagePickerControllerDidCancel(_ picker: UIImagePickerController) {
      self.onComplete(nil)
    }
  }
}
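A side note on the picker: UIImagePickerController still works, but on iOS 14 and later PHPickerViewController is Apple's recommended replacement, and it doesn't require photo-library permission for simple selection. Below is a minimal sketch of the same representable/Coordinator pattern built on PHPicker; PHPickerView is a hypothetical name and is not part of this project.

import PhotosUI
import SwiftUI
import UIKit

struct PHPickerView: UIViewControllerRepresentable {
  @Binding var selectedImage: UIImage?

  func makeUIViewController(context: Context) -> PHPickerViewController {
    var config = PHPickerConfiguration()
    config.filter = .images        // still images only
    config.selectionLimit = 1      // single selection, like the original picker
    let picker = PHPickerViewController(configuration: config)
    picker.delegate = context.coordinator
    return picker
  }

  func updateUIViewController(_ uiViewController: PHPickerViewController, context: Context) {}

  func makeCoordinator() -> Coordinator { Coordinator(self) }

  final class Coordinator: NSObject, PHPickerViewControllerDelegate {
    private let parent: PHPickerView
    init(_ parent: PHPickerView) { self.parent = parent }

    func picker(_ picker: PHPickerViewController, didFinishPicking results: [PHPickerResult]) {
      picker.dismiss(animated: true)
      guard let provider = results.first?.itemProvider,
            provider.canLoadObject(ofClass: UIImage.self) else { return }
      provider.loadObject(ofClass: UIImage.self) { image, _ in
        // loadObject calls back on a background queue.
        DispatchQueue.main.async {
          self.parent.selectedImage = (image as? UIImage)?.upOrientationImage()
        }
      }
    }
  }
}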
4. MLStyleTransferHelper.swift
import Foundation
import SwiftUI
import UIKit

struct MLStyleTransferHelper {
  static var shared = MLStyleTransferHelper()
  private var trainedModelPath: URL?
  mutating func applyStyle(_ styleImg: UIImage, on petImage: UIImage, onCompletion: @escaping (UIImage?) -> Void) {
    let sessionID = UUID()
    let sessionDir = Constants.Path.sessionDir.appendingPathComponent(sessionID.uuidString, isDirectory: true)
    debugPrint("Starting session in directory: \(sessionDir)")
    let petImagePath = Constants.Path.documentsDir.appendingPathComponent("MyPetImage.jpeg")
    let styleImagePath = Constants.Path.documentsDir.appendingPathComponent("StyleImage.jpeg")
    guard
      let petImageURL = petImage.saveImage(path: petImagePath),
      let styleImageURL = styleImg.saveImage(path: styleImagePath)
    else {
      debugPrint("Error Saving the image to disk.")
      return onCompletion(nil)
    }
    do {
      try FileManager.default.createDirectory(at: sessionDir, withIntermediateDirectories: true)
    } catch {
      debugPrint("Error creating directory: \(error.localizedDescription)")
      return onCompletion(nil)
    }
    // 1: Train a style-transfer model on the fly, using the chosen style
    // image and the pet photo as the validation image.
    MLModelTrainer.trainModel(using: styleImageURL, validationImage: petImageURL, sessionDir: sessionDir) { modelPath in
      guard
        let aModelPath = modelPath
      else {
        debugPrint("Error creating the ML model.")
        return onCompletion(nil)
      }
      // 2: Run the freshly trained model on the pet photo to produce
      // the stylized result.
      MLPredictor.predictUsingModel(aModelPath, inputImage: petImage) { stylizedImage in
        onCompletion(stylizedImage)
      }
    }
  }
}
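One thing worth noting: every call to applyStyle(_:on:onCompletion:) creates a fresh session directory under Documents/Session and nothing ever deletes it, so disk usage grows with each run. Here is a minimal cleanup sketch using only Foundation; wipeOldSessions is a hypothetical helper, not part of the original project.

import Foundation

extension MLStyleTransferHelper {
  // Remove all previous training sessions from Documents/Session.
  // Call before starting a new session, once older models are no longer needed.
  func wipeOldSessions() {
    let fileManager = FileManager.default
    guard let sessions = try? fileManager.contentsOfDirectory(
      at: Constants.Path.sessionDir,
      includingPropertiesForKeys: nil) else {
      return
    }
    for sessionURL in sessions {
      try? fileManager.removeItem(at: sessionURL)
    }
  }
}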
5. MLModelTrainer.swift
import Foundation
import CreateML
import Combine

enum MLModelTrainer {
  private static var subscriptions = Set<AnyCancellable>()
  static func trainModel(using styleImage: URL, validationImage: URL, sessionDir: URL, onCompletion: @escaping (URL?) -> Void) {
    // 1: Training data: the single style image plus a directory of
    // content images to train against.
    let dataSource = MLStyleTransfer.DataSource.images(
      styleImage: styleImage,
      contentDirectory: Constants.Path.trainingImagesDir ?? Bundle.main.bundleURL,
      processingOption: nil)
    // 2: Session parameters: where to store checkpoints and how often
    // to report progress and checkpoint.
    let sessionParams = MLTrainingSessionParameters(
      sessionDirectory: sessionDir,
      reportInterval: Constants.MLSession.reportInterval,
      checkpointInterval: Constants.MLSession.checkpointInterval,
      iterations: Constants.MLSession.iterations)
    // 3: Model parameters: the CNN algorithm, a validation image, and
    // the iteration, density and strength knobs from Constants.
    let modelParams = MLStyleTransfer.ModelParameters(
      algorithm: .cnn,
      validation: .content(validationImage),
      maxIterations: Constants.MLModelParam.maxIterations,
      textelDensity: Constants.MLModelParam.styleDensity,
      styleStrength: Constants.MLModelParam.styleStrength)
    // 4: Kick off the asynchronous training job.
    guard let job = try? MLStyleTransfer.train(
      trainingData: dataSource,
      parameters: modelParams,
      sessionParameters: sessionParams) else {
      onCompletion(nil)
      return
    }
    // 5: When the job finishes, write the trained .mlmodel to the
    // session directory and hand back its path.
    let modelPath = sessionDir.appendingPathComponent(Constants.Path.modelFileName)
    job.result.sink(receiveCompletion: { result in
      debugPrint(result)
    }, receiveValue: { model in
      do {
        try model.write(to: modelPath)
        onCompletion(modelPath)
        return
      } catch {
        debugPrint("Error saving ML Model: \(error.localizedDescription)")
      }
      onCompletion(nil)
    })
    .store(in: &subscriptions)
  }
}
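If you want to show real training progress instead of a static "Processing..." label, the MLJob returned by train(trainingData:parameters:sessionParameters:) exposes a standard Progress object that can be observed with KVO. A minimal sketch, assuming MLJob's documented progress property; observeProgress and progressObservation are my own additions:

import CreateML
import Foundation

extension MLModelTrainer {
  // The KVO token must stay alive for as long as the job runs.
  private static var progressObservation: NSKeyValueObservation?

  static func observeProgress(of job: MLJob<MLStyleTransfer>) {
    progressObservation = job.progress.observe(\.fractionCompleted, options: [.new]) { progress, _ in
      debugPrint("Training: \(Int(progress.fractionCompleted * 100))% complete")
    }
  }
}

Calling observeProgress(of: job) right after step 4 would log the completed fraction as training advances.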
6. MLPredictor.swift
import Foundation
import UIKit
import Vision
import CoreML

enum MLPredictor {
  static func predictUsingModel(_ modelPath: URL, inputImage: UIImage, onCompletion: @escaping (UIImage?) -> Void) {
    // 1: Compile the .mlmodel on-device and load the compiled model.
    guard
      let compiledModel = try? MLModel.compileModel(at: modelPath),
      let mlModel = try? MLModel.init(contentsOf: compiledModel)
    else {
      debugPrint("Error reading the ML Model")
      return onCompletion(nil)
    }
    // 2: Wrap the input UIImage in an MLFeatureValue matching the model's
    // image constraint, center-cropping as needed.
    let imageOptions: [MLFeatureValue.ImageOption: Any] = [
      .cropAndScale: VNImageCropAndScaleOption.centerCrop.rawValue
    ]
    guard
      let cgImage = inputImage.cgImage,
      let imageConstraint = mlModel.modelDescription.inputDescriptionsByName["image"]?.imageConstraint,
      let inputImg = try? MLFeatureValue(cgImage: cgImage, constraint: imageConstraint, options: imageOptions),
      let inputImage = try? MLDictionaryFeatureProvider(dictionary: ["image": inputImg])
    else {
      return onCompletion(nil)
    }
    // 3: Run the prediction and pull the stylized image buffer out of
    // the output features.
    guard
      let stylizedImage = try? mlModel.prediction(from: inputImage),
      let imgBuffer = stylizedImage.featureValue(for: "stylizedImage")?.imageBufferValue
    else {
      return onCompletion(nil)
    }
    let stylizedUIImage = UIImage(withCVImageBuffer: imgBuffer)
    return onCompletion(stylizedUIImage)
  }
}
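One more observation: MLModel.compileModel(at:) writes the compiled .mlmodelc into a temporary directory, and this code compiles the model again on every prediction. Apple recommends persisting the compiled model and reusing it. Below is a minimal caching sketch, keyed by the session ID since this app trains a new model per session; cachedCompiledModel is a hypothetical helper, not part of the original project.

import CoreML
import Foundation

// Compile the .mlmodel once per session and keep the .mlmodelc in Caches,
// so repeated predictions against the same model skip the compile step.
func cachedCompiledModel(for modelPath: URL) throws -> MLModel {
  let sessionID = modelPath.deletingLastPathComponent().lastPathComponent
  let cachesDir = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask)[0]
  let cachedURL = cachesDir.appendingPathComponent("\(sessionID).mlmodelc")
  if !FileManager.default.fileExists(atPath: cachedURL.path) {
    let tempURL = try MLModel.compileModel(at: modelPath)
    try FileManager.default.moveItem(at: tempURL, to: cachedURL)
  }
  return try MLModel(contentsOf: cachedURL)
}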
7. Constants.swift
import Foundation

enum Constants {
  enum Path {
    static let trainingImagesDir = Bundle.main.resourceURL?.appendingPathComponent("TrainingData")
    static var documentsDir: URL = {
      return FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0]
    }()
    static let sessionDir = documentsDir.appendingPathComponent("Session", isDirectory: true)
    static let modelFileName = "StyleTransfer.mlmodel"
    static let presetStyle1 = "PresetStyle_1"
  }
  enum MLSession {
    static var iterations = 100
    static var reportInterval = 50
    static var checkpointInterval = 25
  }
  enum MLModelParam {
    static var maxIterations = 200
    static var styleDensity = 128 // Multiples of 4
    static var styleStrength = 5 // Range 1 to 10
  }
}
8. UIImage+Utilities.swift
import Foundation
import UIKit
import CoreImage

extension UIImage {
  func saveImage(path: URL) -> URL? {
    guard
      let data = self.jpegData(compressionQuality: 0.8),
      (try? data.write(to: path)) != nil
    else {
      return nil
    }
    return path
  }
  convenience init?(withCVImageBuffer cvImageBuffer: CVImageBuffer) {
    let ciImage = CIImage(cvImageBuffer: cvImageBuffer)
    let context = CIContext.init(options: nil)
    guard
      let cgImage = context.createCGImage(ciImage, from: ciImage.extent)
    else {
      return nil
    }
    self.init(cgImage: cgImage)
  }
  func upOrientationImage() -> UIImage? {
    switch imageOrientation {
    case .up:
      return self
    default:
      UIGraphicsBeginImageContextWithOptions(size, false, scale)
      draw(in: CGRect(origin: .zero, size: size))
      let result = UIGraphicsGetImageFromCurrentImageContext()
      UIGraphicsEndImageContext()
      return result
    }
  }
}

Afterword

This article mainly covered creating a custom image filter with Style Transfer on an iOS device. If you found it interesting, please leave a like or follow~~~
