util.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # coding: utf-8
  2. import gzip
  3. import os
  4. import shutil
  5. import numpy as np
  6. import paddle.v2 as paddle
  7. def extract_param(path):
  8. """
  9. Return the parameters of weight and bias
  10. """
  11. with gzip.open(path) as f:
  12. p = paddle.parameters.Parameters.from_tar(f)
  13. names = p.names()
  14. weight, bias = [], []
  15. for name in names:
  16. if name.endswith('bias'):
  17. bias.append(p.get(name))
  18. else:
  19. weight.append(p.get(name))
  20. return weight, bias
  21. def param2header(param, path, dtype, var_name):
  22. """
  23. Trans the list of 2d np array into C++ header file
  24. """
  25. with open(path, 'w') as f:
  26. postfix = 1
  27. for matrix in param:
  28. rows = len(matrix)
  29. cols = len(matrix[0])
  30. beg = 'static {0} {1}{2}[{3}]={{'.format(dtype, var_name, postfix, rows * cols)
  31. end = '};\n'
  32. f.write(beg)
  33. for i in range(rows):
  34. for j in range(cols):
  35. if i == 0 and j == 0:
  36. f.write('{:.6f}'.format(matrix[i][j]))
  37. else:
  38. f.write(',{:.6f}'.format(matrix[i][j]))
  39. f.write('\n')
  40. f.write(end)
  41. postfix += 1
  42. # layers = len(param)
  43. # postfix = 1
  44. # beg = 'static {0} **{1} = {0}*[{2}]{{{1}{3}'.format(dtype, var_name, layers, postfix)
  45. # f.write(beg)
  46. # while postfix < layers:
  47. # postfix += 1
  48. # f.write(', {0}{1}'.format(var_name, postfix))
  49. # f.write(end)
  50. def cp_gtsrb_img(src, dst):
  51. """
  52. Copy all GTSRB images from src to dst, just like:
  53. src/000xx/000yy_000zz.ppm -> dst/000xx_000yy_000zz.ppm
  54. """
  55. folders = os.listdir(src)
  56. for folder in folders:
  57. imgs = filter(lambda fname: fname.endswith('ppm'),
  58. os.listdir(src + '/' + folder))
  59. for img in imgs:
  60. shutil.copy(src + '/' + folder + '/' + img,
  61. dst + '/' + folder + '_' + img)
  62. def generate_gtsrb_dict(dir):
  63. imgs = os.listdir(dir)
  64. gtsrb_dict = {}
  65. for img in imgs:
  66. gtsrb_dict[img] = int(img.split('_')[0])
  67. return gtsrb_dict
  68. if __name__ == '__main__':
  69. gtsrb_img_path = 'data/GTSRB/Final_Training/Images'
  70. gtsrb_cp_dst = 'data/GTSRB/Final_Training/All'
  71. param_path = 'data/params_pass_1900.tar.gz'
  72. weight, bias = extract_param(param_path)
  73. for w in weight:
  74. print w.shape
  75. for b in bias:
  76. print b.shape
  77. param2header(weight, 'weight.h', 'float', 'weight')
  78. param2header(bias, 'bias.h', 'float', 'bias')
  79. # cp_gtsrb_img(gtsrb_img_path, gtsrb_cp_dst)
  80. # generate_gtsrb_dict(gtsrb_cp_dst)