fork download
  1. import numpy as np
  2.  
  3. # Example 1: (2, 2, 3) @ (2, 3, 2) → (2, 2, 2)
  4. a = np.array([1,2,3,4,5,6,7,8,9,10,11,12], dtype=np.float32).reshape(2, 2, 3, order='F')
  5. b = np.array([1,2,3,4,5,6,7,8,9,10,11,12], dtype=np.float32).reshape(2, 3, 2, order='F')
  6.  
  7. result1 = np.matmul(a, b)
  8. print("Example 1 Shape:", result1.shape)
  9. print(result1)
  10. print()
  11.  
  12. # Example 2: Broadcasting (3, 1, 2, 3) @ (1, 4, 3, 2) → (3, 4, 2, 2)
  13. c = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
  14. dtype=np.float32).reshape(3,1,2,3, order='F')
  15. d = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
  16. dtype=np.float32).reshape(1, 4, 3, 2, order='F')
  17.  
  18. result2 = np.matmul(c, d)
  19. print("Example 2 Shape:", result2.shape)
  20. print(result2)
  21.  
  22.  
  23.  
  24. data = np.arange(1, 25, dtype=np.float32)
  25.  
  26. # reshape using order='F' to match column-major layout
  27. mat = np.reshape(data, (2, 3, 4), order='F')
  28. vec = np.array([1., 2., 3.], dtype=np.float32)
  29.  
  30. # perform matmul (NumPy handles broadcasting automatically)
  31. result = np.matmul(vec, mat) # equivalent to Tensor::matmul(vec, mat)
  32.  
  33. print("mat shape:", mat.shape)
  34. print("vec shape:", vec.shape)
  35. print("result shape:", result.shape)
  36. print(result)
  37.  
  38.  
  39. vec2 = np.array([1., 2., 3., 4.], dtype=np.float32)
  40. mat2 = np.reshape(np.arange(1, 25, dtype=np.float32), (2, 3, 4), order='F')
  41. result = np.matmul(mat2, vec2)
  42.  
  43. print(result.shape) # (2, 3)
  44. print(result)
  45.  
  46.  
Success #stdin #stdout 0.09s 23616KB
stdin
Standard input is empty
stdout
('Example 1 Shape:', (2, 2, 2))
[[[ 61. 151.]
  [ 79. 205.]]

 [[ 88. 196.]
  [112. 256.]]]
()
('Example 2 Shape:', (3, 4, 2, 2))
[[[[153. 405.]
   [198. 558.]]

  [[174. 426.]
   [228. 588.]]

  [[195. 447.]
   [258. 618.]]

  [[216. 468.]
   [288. 648.]]]


 [[[168. 456.]
   [213. 609.]]

  [[192. 480.]
   [246. 642.]]

  [[216. 504.]
   [279. 675.]]

  [[240. 528.]
   [312. 708.]]]


 [[[183. 507.]
   [228. 660.]]

  [[210. 534.]
   [264. 696.]]

  [[237. 561.]
   [300. 732.]]

  [[264. 588.]
   [336. 768.]]]]
('mat shape:', (2, 3, 4))
('vec shape:', (3,))
('result shape:', (2, 4))
[[ 22.  58.  94. 130.]
 [ 28.  64. 100. 136.]]
(2, 3)
[[130. 150. 170.]
 [140. 160. 180.]]