LeCun点赞的红石神经网络,这里有原理、代码和升级版!-【原创】

TL;DR


  • 第一,介绍了红石神经网络,和 mnist 数据集,以及如何用 pythorch 实现 LeNet-5 得到一个不错的数字识别模型。

  • 第二,介绍了通过 opencv + python + pytorch,实现车牌中的数字识别。

  • 最后,是 mnist train 和 test 的代码,以及车牌数字识别的代码。

红石神经网络与MNIST


【Minecraft】世界首个纯红石神经网络!B站排名最高13名,截止今日点赞 45万、 播放 400万。此间,还得到了 LeCun 的注意,不仅转发了作品,还带上了“Very meta”的评价。

“红石粉(Redstone Dust)是一种放置后可以传输红石信号的物品。在Minecraft中能够被红石控制的机械类别几乎覆盖了你能够想象到的极限,小到简单机械(如自动门、光开关、频闪电源),大到占地巨大的电梯、自动农场、盾构机、小游戏平台,甚至游戏内建的计算机。”也就是说通过红石,我们可以创造出二级管、神经元等等。

大佬开源了 git 代码,可以在B站视频介绍里获取到。红石神经网络,是基于LeCun 在 98年 提出的 LeNet-5 网络结构,在 MNIST 数据集上训练,实现了在 MC 中手写数字识别的能力。

MNIST 数据集,包含了 7万条数据,6 万条训练数据,及 1 万条测试数据。每一条数据包含一个 28×28 的图片,及对应 label。

LeNet-5

LeNet-5 的网络结构如上,包含 3 个卷积层,2 个池化层,以及 2 个全连接层,和输出log_softmax,图中的标志 C@MxN,C 表示通道数,MXN 表示为矩阵大小。以下依次介绍,几个层的原理和作用。

卷积,从滤波的视角看可以起到平滑和去噪的效果,同时是特征提取器,有助于我们找到特定的局部图像特征(如边缘)。卷积后的大小,卷积变化后的图片大小,假设卷积核大小为 (x, y), 则一个 M x N 的矩阵变换为 M-x+1, N-y+1。



B站Up主:辰占鳌头

池化,也成为子采样层(subsampling layer),其作用是进行特征选择,降低特征数量,并从而减少参数数量。一般通过简单的最大值、最小值或平均值操作完成。假设池化核大小为 (x,y),变换为 M/x, N/y。


全连接层,用来做矩阵线性变换。log_softmax 即 log(softmax(x)) ,解决了 softmax 的 函数上溢和下溢的问题,softmax 本身是对输出的 logits 做归一化的过程。

了解了池化、卷积、全连接和 log_softmax 的作用。照搬 LeNet-5,得到如下的 pytorch 模型定义。(新手推荐使用 pytorch 上手,不是说 tf 不好,而在于 tf 的理解成本太高。以 HuggingFace 为例, 平台有 85% 模型是 PyTorch 独有, 8% 是 TensorFlow 独有。可见一斑…. )

import torch.nn as nn
import torch.nn.functional as F
import collections

class LeNet5(nn.Module):

    def __init__(self):    
        super(LeNet5, self).__init__()
        # input x = 28 * 28 
        self.conv = nn.Sequential( collections.OrderedDict([
            ('c1', nn.Conv2d(1, 6, kernel_size=(5,5))),
            ('r1', nn.ReLU()),
            ('s2', nn.MaxPool2d(kernel_size=(2,2), stride=2)), 
            ('c3', nn.Conv2d(6, 16, kernel_size=(5,5))),
            ('r3', nn.ReLU()), 
            ('s4', nn.MaxPool2d(kernel_size=(2,2), stride=2)), 
            ('c5', nn.Conv2d(16, 120, kernel_size=(5,5))),
            ('r5', nn.ReLU())
        ]))

        self.fc = nn.Sequential( collections.OrderedDict([
            ('f6', nn.Linear(120, 84)),
            ('r6', nn.ReLU()),  
            ('f7', nn.Linear(84, 10)),
            ('r7', nn.ReLU()),   
            ('sig7', nn.LogSoftmax(dim=1))
        ]))

        def forward(self, x):
            y = self.conv(x)
            y = y.view(-1, 120)
            y = self.fc(y)
        return y 

train.py 和 eval.py 的代码在文末,已添加详细的注释,保证小白可懂。代码非常简单,实测 10min 搞定。两个注意点在,训练过程需要 opt.zero_grad() 清空累积梯度,预测时 with torch.no_grad(), 确定传递过程不需更新梯度。测试集上 99% 识别准确率,迭代到 200 轮接近收敛。


同时,也鼓励大家自己修改网络结构,比如在文末 model.py 中我增加了 dropout 等。当下 pytorch 这类工具基本都积木化了,虽然还是要写代码… 想直接拿现成代码跑测试的同学,关注、私信、留言。

如果仅此,那么这篇文章最大的价值仅在于,是所有被 pytroch minist 检索到的文章中,对原理和代码最为深入浅出的。

为了升级难度,下面我们用代码实现车牌中的数字识别。

车牌数字识别


车牌识别是干啥的,咱就不展开了… 下面介绍如何用 python + opencv + pytorch 实现,自己的车牌数字识别程序。输入是一张带有车牌的照片,输出是检测到的字符区块与预测结果。


仅识别数字部分,白色车牌数字切片图片,右下橘色为预测结果

流程逻辑,大致可分为‘车牌区域提取’(下方1-3子图)、‘车牌字符分割’(下方2-2子图),以及‘图像转文字’(下方的2-3子图)三部分。


识别算法原理

  1. ‘车牌区域提取’:对原始图像做 滤波 Sobe边缘检测 X轴向的闭运算 Y轴向的开运算 等以及 形态学处理 ,得到对应的 二值化图像 。以该图像做 轮廓提取 ,得到符合车牌规范的轮廓候选集,例如我们定义长宽比在2倍到4倍之间的。
  2. ‘车牌字符分割’:得到车牌区域后, 去除噪声边框 ,如车牌的黑色包边。在去除边框后,再经过去噪和形态运算,做车牌区域内的 轮廓提取 ,筛选得到符合数字的轮廓座位数字候选,并过滤掉无效轮廓,例如车牌铆钉区域等,例如我们定义轮廓高要占满至少大于 80% 的图片高度,轮廓宽度不能超过图片宽度的 1/4,但大于 1/20 (数字1)。
  3. ‘图像转文字’:上面训练完毕的 mnist 模型,只能识别数字,这里的代码仅支持车牌中的数字识别,将提取到的潜在车牌字符图片,做 图片转换入参 为 n x 28 x 28。还可以进一步把数字、汉字和字母的识别,统一做到预测模型中。(精力受限,后续展开)

滤波去噪、边缘检测、形态学运算,目的都是为了通过降噪和形态变换,以提升轮廓提取的准确性。

滤波去噪,含均值滤波 、高斯滤波等。原理是拿周边的n个近邻像素点,做像素和色彩的计算。用以降低噪声。例如中值滤波,顾名思义,是以其周围指定范围的邻接像素色彩排序后,以它们的中值做当前像素颜色。可以做到平滑、降噪的作用,高斯滤波亦然。以下图为例,就是对(x,y)点的像素值赋值,不同滤波方式,对应的计算公式不同。



图片来自知乎:圆圆要学习

边缘检测,目的是标识数字图像中亮度变化明显的点,图像属性中的显著变化通常反映了属性的重要事件和变化。



形态学运算,包含开运算和闭运算,其中开运算先‘腐蚀’再‘膨胀’,闭运算先‘膨胀’再‘腐蚀’。

  • 腐蚀,边缘内部切割,瘦身物体轮廓(边缘),分离物体间的连接,消除离散点。

  • 膨胀,边缘外部塑造,丰满物体轮廓(边缘),填充物体间的孔洞,强化离散点。

  • 闭运算,先膨胀后腐蚀,可以用来弥合窄距沟壑,消除物体间小的孔洞,填补轮廓线中的断裂。

  • 开运算,先腐蚀后膨胀,可以用来平滑物体轮廓,断开物体间窄的连接,消除轮廓边沿的尖刺。



轮廓提取,contours,hierarchy = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE),opencv 自带提供了轮廓提取的方法,其中 img 一般为经过边缘检测和形态处理的二值图、灰度图。

  • 第二个参数,为轮廓检测方式:RETR_EXTERNA 只检测最大外围的轮廓,RETR_LIST 检测所有的轮廓,包括内围、外围轮廓,RETR_CCOMP 检测所有的轮廓,若外围包含内围轮廓,则内围内的所有轮廓均归属于顶层结构。RETR_TREE, 检测所有轮廓,建立一个等级树结构。

  • 第三个参数,为轮廓保留方式:CHAIN_APPROX_NONE 保存物体边界上所有连续的轮廓,CHAIN_APPROX_SIMPLE 仅保存轮廓的拐点信息,CHAIN_APPROX_TC89_L1,CHAIN_APPROX_TC89_KCOS使用 teh-Chinl chain 近似算法。

  • 算法逻辑:见论文 Satoshi Suzuki and others. Topological structural analysis of digitized binary images by border following。

  • 下图列举几个不同参数的差异,下图主要是 RETR_EXTERNAL 和其余方式会有较大差异,剩余几种的差异主要体现在 hierarchy 的结果上。



图片转换入参,将提取到的数字图片,cv2.resize(image, (28, 28)) 成 28×28 像素的,以便后续输入到 mnist 模型预测。numpy ndarray 到 tenrsor 的转换可以用 torch.from_numpy(ndarray) ,反之用 tensor.numpy()。

以下,是更多的程序结果示例。可以看到,对于车牌数字部分的识别的准确率较好。同时,也能够较好处理,不同角度、不同大小的场景。







目前的代码,能够处理大部分情形。但还存在一些问题,如输入照片尽量是车牌正前方的,因车牌内的数字轮廓检测不精准,包含部分非数字结果。后续优化点,包含但不限于:端到端的车牌区域监测,倾斜视角的矫正,无效数字轮廓过滤等。

运行环境依赖


opencv-Python
torch
torchvision
sklearn
matplotlib
torchsummary

conda create -n mnist python=3.7
conda activate mnist
pip install -r requirments.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

MINIST训练代码


model.py

model 的定义,LeNet5 为标准的结构, Net 为自定义的网络结构。

import torch.nn as nn
import torch.nn.functional as F
import collections

class LeNet5(nn.Module):
    def __init__(self):

        super(LeNet5, self).__init__()
        # input x = 28 * 28 
        self.conv = nn.Sequential( collections.OrderedDict([
            ('c1', nn.Conv2d(1, 6, kernel_size=(5,5))),
            ('r1', nn.ReLU()),
            ('s2', nn.MaxPool2d(kernel_size=(2,2), stride=2)), 
            ('c3', nn.Conv2d(6, 16, kernel_size=(5,5))),
            ('r3', nn.ReLU()), 
            ('s4', nn.MaxPool2d(kernel_size=(2,2), stride=2)), 
            ('c5', nn.Conv2d(16, 120, kernel_size=(5,5))),
            ('r5', nn.ReLU())
        ]))

        self.fc = nn.Sequential( collections.OrderedDict([
            ('f6', nn.Linear(120, 84)),
            ('r6', nn.ReLU()),  
            ('f7', nn.Linear(84, 10)),
            ('r7', nn.ReLU()),   
            ('sig7', nn.LogSoftmax())
        ]))

    def forward(self, x):
        y = self.conv(x)
        y = y.view(-1, 120)
        y = self.fc(y)
        return y

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50) #320 = 4*4*20 
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

train.py

注意 opt.zero_grad() 需要每次清空批次的梯度,因为 pytorch 是累积梯度的。

from torch.nn import init
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
from utils import mnist_data_loader, plot_loss
from model import Net

# 训练参数定义
epoch=100
batch_size=512
learning_rate = 0.001
model=Net() #LeNet5()

# 优化器和多分类损失函数
opt = torch.optim.Adam(model.parameters(), learning_rate)
criterion = nn.CrossEntropyLoss() 

# 过程记录使用
loss_sequence = []
loss_best = np.inf
record_step = 0

train_loader, _ = mnist_data_loader(batch_size)

# 模型参数初始
for layer in model.modules():
    if isinstance(layer,nn.Linear):
        init.xavier_uniform_(layer.weight)

# 迭代训练过程
for i in range(epoch):
    for batch_idx, (data,label) in enumerate(train_loader):
        # 清空累积梯度
        opt.zero_grad()  
        outputs = model(data)
        loss = criterion(outputs, label)

        # 反向传播,参数更新
        loss.backward() 
        opt.step()

        print('epoch[{}], train loss {:.6f}, dealed/records: {}/{}.'
              .format(i,loss/batch_size,(batch_idx+1)*batch_size
                      ,len(train_loader.dataset)))

        # 每100轮保存模型
        if batch_idx % 100 == 0:
            record_step+=1
            loss_sequence.append([record_step, loss/batch_size])
            if loss < loss_best:
                loss_best = loss
                torch.save(model, "model-lenet.ckpt")
# 打印 loss 历史
plot_loss(loss_sequence)

eval.py

输出 minist 在评测集上的准确率。

from cv2 import circle
import torch.nn as nn
import torch
from utils import mnist_data_loader
from model import Net, LeNet5

# 模型加载
model_path = 'model-mnist.ckpt'
model = torch.load(model_path)
batch_size = 512

# 读取测试集
_, test_loader = mnist_data_loader(batch_size)

# 无需梯度
model.eval()
with torch.no_grad():
  acc, all = 0, 0
  for batch_idx, (data,label) in enumerate(test_loader):
      outputs = model(data)
      correct = torch.sum(torch.argmax(outputs, dim=1) == label)
      all += len(label) 
      acc += correct
      print('Test batch acc {:.6f}, Dealed/Records:{}/{}'
            .format(correct/len(label), all, len(test_loader.dataset)))

  print ('total precision {:.6f}'.format(acc / all))

utils.py

主要提供了 mnist data 的读取,以及展示训练 loss 和迭代次数的关系。

def mnist_data_loader(batch_size):
    train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST('./data/', train=True, download=True,
                transform=torchvision.transforms.Compose([
                    # mnist 输入为 28x28, LeNet-5 结构第一层为 32x32 
                    torchvision.transforms.Resize(32),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize((0.1307,), (0.3081,))
            ]))
            , batch_size=batch_size
            , shuffle=True)
    print ('train dataset size:%s'%len(train_loader.dataset))
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=False, download=True,
            transform=torchvision.transforms.Compose([
                # mnist 输入为 28x28, LeNet-5 结构第一层为 32x32 
                torchvision.transforms.Resize(32),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ]))
        , batch_size=batch_size
        , shuffle=False)
    print ('test dataset size:%s'%len(test_loader.dataset))
    return train_loader, test_loader

def plot_loss(loss_sequence):
    fig = plt.figure(figsize=(20,15))
    fig.autofmt_xdate()
    loss_df = pd.DataFrame(loss_sequence,columns=['time','loss'])
    plt.ylabel('loss')
    plt.xlabel('times')
    plt.plot(loss_df['loss'].values)
    plt.xticks([10,30,50,70,80,100,120,140,160,200,100])
    plt.show()

车牌数字识别代码


对每部分都做了注释,上文也阐述了轮廓监测中,去噪处理+形态变换的重要性。其中比较重要的有 x 向的闭运算的宽度的调整,会影响到车牌提取的准确性。最好的办法,是直接用端到端的检测到车牌区域,比如用 yolo 去训练一个车牌区域检测模型。想直接拿现成代码跑测试的同学,关注、私信、留言。

mnist_lpr.py

from distutils.log import debug
from genericpath import exists
from tkinter import E
import cv2
import numpy as np
from torchvision import transforms
import torch
from model import Net, LeNet5
from utils import *

# 车牌识别

class LisencePlateRecognizer():
    def __init__(self, model_path):
        self.is_debug = False
        self.model = torch.load(model_path) #load minist model
        self.hsv_blue_low = np.array([100, 43, 46]) #蓝色下界
        self.hsv_blue_up = np.array([124, 255, 255]) #蓝色上界

    def tirgger_debug(self, flg):
        self.is_debug = flg;

    # 图像去噪、边缘检测、形态化,以及二值输出

    def denoisy_2_binary_img(self, image):
        # 蓝色 masking x 高斯模糊-灰度图 
        gray_img_1 = cv2.GaussianBlur(image, (1, 1), 0)
        gray_img_1 = cv2.cvtColor(gray_img_1, cv2.COLOR_BGR2GRAY)
        gray_img_2 = self.get_bi_image_use_color_mask(image)
        gray_image = np.array(gray_img_1) * (np.array(gray_img_2))

        # x方向上的边缘检测(增强边缘信息)
        sobel_x = cv2.Sobel(gray_image, cv2.CV_16S, 1, 0)
        image = cv2.convertScaleAbs(sobel_x)

        # 图像阈值化操作——获得二值化图,THRESH_OTSU 为自动寻找最优的分割阈值
        _, image = cv2.threshold(image, 0, 255, cv2.THRESH_OTSU)

        # 闭操作, 让图片更加的丰满 
        kernel_x = cv2.getStructuringElement(cv2.MORPH_RECT, (45, 1))
        image = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel_x, iterations = 1)

        # 腐蚀(erode)和膨胀(dilate)
        kernel_x = cv2.getStructuringElement(cv2.MORPH_RECT, (30, 1))
        image = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel_x, iterations = 1)

        # 纵向是开运算, 让细节分割
        kernel_y = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 20))
        image = cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel_y, iterations = 1)

        self.desc_img_blur = image.copy()
        return image

    # 得到以蓝底为区域的二值图

    def get_bi_image_use_color_mask(self, image):  
        img = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2HSV) 
        img = cv2.GaussianBlur(img, (1, 1), 0)
        img = cv2.inRange(img, self.hsv_blue_low, self.hsv_blue_up)
        kernel = np.ones((1, 1), np.uint8)
        img = cv2.erode(img, kernel, iterations=1) 
        return img

    # 是否包含蓝底的车牌颜色

    def has_legal_back_ground(self, image, x, y, width, height):
        color_img = self.get_bi_image_use_color_mask(image[y:y + height, x:x + width])
        c, h = cv2.findContours(color_img.copy()
                                , cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 
        if c is not None and len(c)>0:
            return True
        return False

# 找到车牌的区域 regions of interest

    def get_car_lisence_plat_roi(self, img_denoise_bi, img_original):

        # 对降噪和联通后的图二值化后的的图像做边缘检测
        contours, _ = cv2.findContours(img_denoise_bi
                                       , cv2.RETR_EXTERNAL
                                       , cv2.CHAIN_APPROX_SIMPLE)
        best_legal_color_rate = 0.0
        plate_roi = None
        plate_roi_cands = []
        box = []
        for item in contours:
            x, y, width, height = cv2.boundingRect(item)
            # 假定车牌的宽度为长度的1.5倍 - 4.0背之间
            #  才会被当做可能的车牌轮廓区域,做进一步的检测。
            if (width > (height * 1.5)) and (width < (height * 4.0)):
                if self.has_legal_back_ground(img_original, x, y, width, height):
                    img_t = img_original[y:y + height, x:x + width]
                    img_legal_color_bi = self.get_bi_image_use_color_mask(img_t)
                    hit, rate = get_white_color_statistics(img_legal_color_bi) 
                    if rate > best_legal_color_rate:
                        best_legal_color_rate = rate
                        img_denoise_bi = img_original[y:y + height, x:x + width]
                        plate_roi_cands.append(img_denoise_bi)
                        plate_roi = img_denoise_bi
                        box.append(item)
        show_contours(box, img_original, self.is_debug)
        self.desc_car_plate = draw_contours(box, img_original, self.is_debug)
        return plate_roi_cands, plate_roi

    # 找到车牌内部的数字 region of interest

    def get_car_lisence_chars_rois(self, img_car_plate):
        contours, _ = cv2.findContours(img_car_plate
                                       , cv2.RETR_EXTERNAL
                                       , cv2.CHAIN_APPROX_SIMPLE)
        numbers_rois = []
        for box in contours:
            rect = cv2.boundingRect(box)
            w, h = rect[2], rect[3]
            if h / img_car_plate.shape[0] >= 5.0 / 10.0 \
                and w / img_car_plate.shape[1] >= 1.0/50.0 \
                and w / img_car_plate.shape[1] <= 1.0/4.0:
                numbers_rois.append(box)
        show_contours(numbers_rois, img_car_plate, self.is_debug)
        self.desc_car_plate_numbers = draw_contours(numbers_rois
                                                    , img_car_plate, self.is_debug)
        return numbers_rois

    # 把检测到的数字区块,转换成模型的预测输入 x

    def convert_to_pred_input(self, box, img_car_plate):
        numbers = []
        for item in box:
            rect = cv2.boundingRect(item)
            x, y, width, height = rect
            image = img_car_plate[y:y + height, x:x + width]
            border_size = int(image.shape[0] / 3)
            image = cv2.copyMakeBorder(image,border_size,border_size
                                       ,border_size,border_size,
                                       borderType=cv2.BORDER_CONSTANT,value=0)
            image = cv2.resize(image, (28, 28)) # minist model input size
            numbers.append([x,y,image])
        inputs = []
        for num in sorted(numbers, key=lambda x:(x[0], x[1])):
            inputs.append([num[2]])
        input = np.array(inputs)
        input = input.astype(np.float32)
        return torch.from_numpy(input)

    # 用训练好的 mnist 模型, 预测当前图形的结果

    def predict(self, img_path):
        self.model.eval()
        image = cv2.imread(img_path)
        self.desc_original_image = image.copy()
        plt_show(image, self.is_debug)

        img_denoise_bi = self.denoisy_2_binary_img(image)
        plt_show(img_denoise_bi, self.is_debug)

        car_plate_cands, img_car_plate  = self.get_car_lisence_plat_roi(
            img_denoise_bi, image)

        plate_cands = [img_car_plate] + car_plate_cands
        for img_car_plate in plate_cands:
            try:
                plt_show(img_car_plate, self.is_debug)
                self.desc_car_plate_sub = img_car_plate
                img_car_plate = remove_noise_border(img_car_plate)
                plt_show(img_car_plate, self.is_debug)
                self.desc_car_plate_clean = img_car_plate
                img_chars = self.get_car_lisence_chars_rois(img_car_plate)
                inputs = self.convert_to_pred_input(img_chars, img_car_plate)  
                with torch.no_grad():
                    preds = self.model(inputs)
                    numbs = torch.argmax(preds, dim=1)
                    self.desc_result_image = darw_result_img(inputs, numbs)
                self.show_detail_process()
                break
            except Exception:
                print (Exception)

    # 打印过程结果,方便debug。

    def show_detail_process(self):
        _, ax = plt.subplots(2,3,figsize=(20,20))
        ax[0,0].imshow(self.desc_original_image)
        ax[0,0].set_title("desc_original_image")
        ax[0,1].imshow(self.desc_img_blur)
        ax[0,1].set_title("desc_img_blur")
        ax[0,2].imshow(self.desc_car_plate)
        ax[0,2].set_title("desc_car_plate")
        ax[1,0].imshow(self.desc_car_plate_sub)
        ax[1,0].set_title("desc_car_plate_sub")
        ax[1,1].imshow(self.desc_car_plate_numbers)
        ax[1,1].set_title("desc_car_plate_numbers")
        ax[1,2].imshow(self.desc_result_image)
        ax[1,2].set_title("desc_result_image")
        plt.tight_layout()
        plt.show()

if __name__ == '__main__':
    import sys
    mnist_model_path = 'model1.ckpt'
    lpr = LisencePlateRecognizer(mnist_model_path)
    # lpr.tirgger_debug(True)
    lpr.predict(sys.argv[1].strip())

最后

更多Python知识尽在【Python都知道】公众号,欢迎大家!!
扫描下方二维码,关注公众号,了解更多Python内容


小白学堂 » LeCun点赞的红石神经网络,这里有原理、代码和升级版!-【原创】

就聊挣钱,一个带着你做副业的社群。

立即查看 了解详情