File indexing completed on 2024-04-06 12:28:17
0001
0002
0003
0004
0005
0006
package GenMul;

# File-scoped lexical: visible only inside this file.
# NOTE(review): the name suggests a vector width, but it is not referenced
# anywhere else in this file -- TODO confirm whether it is still needed.
my $G_vec_width = 1;
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
package GenMul::MBase;

use Carp;

# Abstract base class for the matrix descriptions GenMul generates code for.
#
# Objects are blessed hashes built from the constructor's key/value list.
# Keys used in this file:
#   name    - required; variable name used in the generated code
#   M, N    - row and column counts (checked by the concrete subclasses)
#   pattern - optional per-element '0' / '1' / 'x' list, see set_pattern()
#   class   - the blessed class name, stored by new()
#
# Concrete subclasses must override mat_size() and idx().

# Construct an object from a key/value list.  Croaks unless 'name' is given.
sub new
{
    my $proto = shift;
    my $class = ref($proto) || $proto;
    my $S = {@_};
    bless($S, $class);

    croak "name must be set" unless defined $S->{name};

    $S->{class} = $class;

    return $S;
}

# Number of stored elements.  Abstract: dies unless overridden.
sub mat_size
{
    die "mat_size() should be overriden in concrete matrix class";
}

# Storage index of element (i, j).  Abstract: dies unless overridden.
sub idx
{
    die "idx() should be overriden in concrete matrix class";
}

# True when 0 <= i < M and 0 <= j < N.
sub row_col_in_range
{
    my ($S, $i, $j) = @_;

    # `&&`, not `and`: `return` binds tighter than `and`, so the former
    # `return A and B ...` returned A alone and ignored the other bounds.
    return $i >= 0 && $i < $S->{M} && $j >= 0 && $j < $S->{N};
}

# Set the element pattern from a whitespace-separated string with one entry
# per stored element, in storage order:
#   '0' - element is zero (its terms are dropped by the code generators)
#   '1' - element is one (the multiplication by it is elided)
#   'x' - arbitrary value
# Croaks on a count mismatch or an invalid entry.  The pattern is stored
# only after the whole string has been validated.
sub set_pattern
{
    my ($S, $pstr) = @_;

    # split ' ' (awk mode) ignores leading whitespace, so an indented or
    # newline-initial string does not gain a spurious empty first entry.
    my @entries = split ' ', $pstr;

    croak "set_pattern number of entries does not match matrix size"
        unless scalar @entries == $S->mat_size();

    # Anchored: each entry must be exactly one of 0, 1, x.  The old
    # unanchored /0|1|x/ also accepted entries such as "10" or "x2".
    croak "set_pattern() input string contains invalid entry"
        if grep { !/\A[01x]\z/ } @entries;

    @{$S->{pattern}} = @entries;

    return;
}

# Pattern entry for storage index $idx; 'x' when no pattern was set.
sub pattern
{
    my ($S, $idx) = @_;

    die "pattern called with bad index."
        unless $idx >=0 and $idx < $S->mat_size();

    return defined $S->{pattern} ? $S->{pattern}[$idx] : 'x';
}

# Register name used in generated code for storage index $idx: "<name>_<idx>".
sub reg_name
{
    my ($S, $idx) = @_;

    die "reg_name called with bad index."
        unless $idx >=0 and $idx < $S->mat_size();

    return "$S->{name}_${idx}";
}

# Print a one-line summary of the object.
sub print_info
{
    my ($S) = @_;

    print "Class='$S->{class}', M=$S->{M}, N=$S->{N}, name='$S->{name}'\n";
}

# Print the pattern as an M x N grid, one row per line.
sub print_pattern
{
    my ($S) = @_;

    for (my $i = 0; $i < $S->{M}; ++$i)
    {
        for (my $j = 0; $j < $S->{N}; ++$j)
        {
            print $S->pattern($S->idx($i, $j)), " ";
        }
        print "\n";
    }
}
0129
0130
0131
0132
0133
package GenMul::Matrix;

our @ISA = ('GenMul::MBase');

use Carp;

# Dense M x N matrix stored row by row: element (i, j) lives at i*N + j.

# Construct via GenMul::MBase::new, then require both dimensions.
sub new
{
    my $proto = shift;
    my $self  = $proto->SUPER::new(@_);

    for my $dim ('M', 'N')
    {
        croak "$dim not set for $self->{class}" unless defined $self->{$dim};
    }

    return $self;
}

# Every element is stored: M * N.
sub mat_size
{
    my $self = shift;

    return $self->{N} * $self->{M};
}

# Row-major storage index of (row, col); confesses when out of range.
sub idx
{
    my ($self, $row, $col) = @_;

    if ($row < 0 or $row >= $self->{M})
    {
        confess "$self->{class}::idx() i out of range";
    }
    if ($col < 0 or $col >= $self->{N})
    {
        confess "$self->{class}::idx() j out of range";
    }

    my $offset = $row * $self->{N};
    return $offset + $col;
}
0169
0170
0171
0172
0173
package GenMul::MatrixSym; our @ISA = ('GenMul::MBase');

use Carp;

# Symmetric M x M matrix in packed storage: only the lower triangle
# (diagonal included) is stored, row by row, giving (M+1)*M/2 elements.
# $Offs[$M][$i * $M + $j] is the storage index of element (i, j); the table
# is symmetric in i and j.  Scalar element assignment ($Offs[n], not the
# one-element slice @Offs[n]) avoids "better written as" warnings.
my @Offs;
$Offs[2] = [ 0, 1, 1, 2 ];
$Offs[3] = [ 0, 1, 3, 1, 2, 4, 3, 4, 5 ];
$Offs[4] = [ 0, 1, 3, 6, 1, 2, 4, 7, 3, 4, 5, 8, 6, 7, 8, 9 ];
$Offs[5] = [ 0, 1, 3, 6, 10, 1, 2, 4, 7, 11, 3, 4, 5, 8, 12, 6, 7, 8, 9, 13, 10, 11, 12, 13, 14 ];
$Offs[6] = [ 0, 1, 3, 6, 10, 15, 1, 2, 4, 7, 11, 16, 3, 4, 5, 8, 12, 17, 6, 7, 8, 9, 13, 18, 10, 11, 12, 13, 14, 19, 15, 16, 17, 18, 19, 20 ];

# Return the offset table for a $dim x $dim symmetric matrix, or undef when
# $dim is not a positive integer.  Dimensions without a literal table above
# are built from the same packing rule and cached in @Offs:
#   idx(i, j) = h*(h+1)/2 + l,  with h = max(i, j), l = min(i, j).
sub _sym_offsets
{
    my ($dim) = @_;

    return unless defined $dim and $dim =~ /\A\d+\z/ and $dim >= 1;

    unless (defined $Offs[$dim])
    {
        my @table;
        for my $i (0 .. $dim - 1)
        {
            for my $j (0 .. $dim - 1)
            {
                my ($hi, $lo) = $i >= $j ? ($i, $j) : ($j, $i);
                push @table, $hi * ($hi + 1) / 2 + $lo;
            }
        }
        $Offs[$dim] = \@table;
    }

    return $Offs[$dim];
}

# Construct via GenMul::MBase::new.  M is required; N may be omitted
# (it is then set to M) or must equal M.
sub new
{
    my $proto = shift;
    my $S = $proto->SUPER::new(@_);

    croak "M not set for $S->{class}" unless defined $S->{M};

    croak "N should not be set or should be equal to M for $S->{class}"
        if defined $S->{N} and $S->{N} != $S->{M};

    my $offs = _sym_offsets($S->{M});

    die "Offset array not defined for this dimension"
        unless defined $offs;

    die "Offset array of wrong dimension"
        unless scalar @$offs == $S->{M} * $S->{M};

    $S->{N} = $S->{M} unless defined $S->{N};

    return $S;
}

# Number of stored (lower-triangle) elements: (M+1)*M/2.
sub mat_size
{
    my ($S) = @_;

    return ($S->{M} + 1) * $S->{M} / 2;
}

# Packed storage index of (i, j); (i, j) and (j, i) share the same index.
sub idx
{
    my ($S, $i, $j) = @_;

    confess "$S->{class}::idx() i out of range"
        if $i < 0 or $i >= $S->{M};

    confess "$S->{class}::idx() j out of range"
        if $j < 0 or $j >= $S->{N};

    return $Offs[$S->{M}][$i * $S->{N} + $j];
}
0226
0227
0228
0229
0230
0231
package GenMul::MatrixTranspose; our @ISA = ('GenMul::MBase');

use Carp;
use Scalar::Util 'blessed';

# View of another GenMul::MBase matrix as its transpose.  Element (i, j)
# maps to element (j, i) of the wrapped matrix; the name, element count and
# pattern are taken from the wrapped matrix.

# new($matrix, key => value, ...): wrap $matrix.  Extra key/value pairs go
# to GenMul::MBase::new, but 'name' is always taken from $matrix because it
# is appended after them.
sub new
{
    my $proto = shift;
    my $mat = shift;

    # $S does not exist yet, so the class for the message comes from $proto
    # (the old message interpolated an undefined $S->{class}).
    my $class = ref($proto) || $proto;

    croak "Argument for $class is not a GenMul::MBase"
        unless blessed $mat and $mat->isa("GenMul::MBase");

    my $S = $proto->SUPER::new(@_, 'name'=>$mat->{name});

    $S->{matrix} = $mat;

    # Dimensions are swapped relative to the wrapped matrix.
    $S->{M} = $mat->{N};
    $S->{N} = $mat->{M};

    return $S;
}

# Same number of stored elements as the wrapped matrix.
sub mat_size
{
    my ($S) = @_;

    return $S->{matrix}->mat_size();
}

# Element (i, j) of the transpose is element (j, i) of the wrapped matrix.
sub idx
{
    my ($S, $i, $j) = @_;

    return $S->{matrix}->idx($j, $i);
}

# Patterns are indexed by storage index, so they are shared unchanged.
sub pattern
{
    my ($S, $idx) = @_;

    return $S->{matrix}->pattern($idx);
}

# Print the wrapped matrix's info followed by this view's own info.
sub print_info
{
    my ($S) = @_;

    print "Transpose of ";
    $S->{matrix}->print_info();
    print " ";
    $S->SUPER::print_info();
}
0290
0291
0292
0293
0294
0295
0296
0297
package GenMul::Multiply;

use Carp;
use Scalar::Util 'blessed';

use warnings;

# Code generator for c = a * b over GenMul::MBase matrix descriptions.
#
# new(key => value, ...) accepts arbitrary options; these are read here or
# by the methods below:
#   prefix        - prepended to generated code lines (default " ")
#   vectype       - type name used when declaring registers in intrinsic
#                   code (default "IntrVec_t")
#   no_size_check - when true, check_multiply_arguments() only warns about
#                   dimension mismatches
sub new
{
    my ($proto, @opts) = @_;

    my $class = ref($proto) || $proto;

    my $self = { @opts };
    bless $self, $class;

    my %default_of = (prefix => " ", vectype => "IntrVec_t");
    for my $key (sort keys %default_of)
    {
        next if defined $self->{$key};
        $self->{$key} = $default_of{$key};
    }

    $self->{class} = $class;

    return $self;
}
0322
# Validate the operands of c = a * b and remember a and b in
# $S->{a}{mat} / $S->{b}{mat} for the index/pattern helpers.
#
# Croaks when an operand is not a GenMul::MBase object.  Dimension mismatches
# croak, or only warn (carp) when $S->{no_size_check} is set.  Also croaks when
# c is a transpose or has a pattern, and warns when c is symmetric.
sub check_multiply_arguments
{
    my ($self, $lhs, $rhs, $res) = @_;

    my @operands = ([a => $lhs], [b => $rhs], [c => $res]);
    for my $pair (@operands)
    {
        my ($label, $m) = @$pair;
        croak "Input $label is not a GenMul::MBase"
            unless blessed $m and $m->isa("GenMul::MBase");
    }

    # Same checks either way; only the severity and message suffix differ.
    my ($report, $suffix) = $self->{no_size_check}
        ? (\&carp,  " -- running with no_size_check")
        : (\&croak, "");

    $report->("Input matrices a and b not compatible$suffix")
        unless $lhs->{N} == $rhs->{M};

    $report->("Result matrix c of wrong dimensions$suffix")
        unless $res->{M} == $lhs->{M} and $res->{N} == $rhs->{N};

    if ($res->isa("GenMul::MatrixTranspose"))
    {
        croak "Result matrix c should not be a transpose (or check & implement this case in GenMul code)";
    }

    if (defined $res->{pattern})
    {
        croak "Result matrix c has a pattern defined, this is not yet supported (but shouldn't be too hard).";
    }

    if ($res->isa("GenMul::MatrixSym"))
    {
        carp "Result matrix c is symmetric, GenMul hopes you know what you're doing";
    }

    $self->{a}{mat} = $lhs;
    $self->{b}{mat} = $rhs;
}
0365
# Append one output line, the concatenation of all arguments, to $S->{out}.
# Returns the new number of buffered lines.
sub push_out
{
    my ($self, @pieces) = @_;

    my $line = join '', @pieces;
    push @{ $self->{out} }, $line;
}
0372
# Prepend one output line, the concatenation of all arguments, to $S->{out}.
# Returns the new number of buffered lines.
sub unshift_out
{
    my ($self, @pieces) = @_;

    my $line = join '', @pieces;
    unshift @{ $self->{out} }, $line;
}
0379
# Prepend declarations of the all_zeros / all_ones constant vectors to the
# buffered output when the generated code refers to them.  A 16-lane
# initializer is used under AVX512_INTRINSICS and an 8-lane one otherwise,
# and the block is followed by an empty line.
sub handle_all_zeros_ones
{
    my ($S, $zeros, $ones) = @_;

    return unless $zeros or $ones;

    my $decl = sub {
        my ($var, $value, $lanes) = @_;
        return "$S->{vectype} $var = { " . join(", ", ($value) x $lanes) . " };";
    };

    my @lines = ("#ifdef AVX512_INTRINSICS");
    push @lines, $decl->('all_zeros', 0, 16) if $zeros;
    push @lines, $decl->('all_ones',  1, 16) if $ones;
    push @lines, "#else";
    push @lines, $decl->('all_zeros', 0, 8)  if $zeros;
    push @lines, $decl->('all_ones',  1, 8)  if $ones;
    push @lines, "#endif", "";

    # Unshift from the end so the block keeps its order at the front.
    for my $line (reverse @lines)
    {
        $S->unshift_out($line);
    }
}
0414
# Drop the per-multiplication scratch state kept for operands a and b
# (their mat / idx / pat / cnt entries) once a multiply_* call has finished.
# The former version wrapped these deletes in a pointless loop over an
# undeclared global $k that deleted the same two keys twice.
sub delete_temporaries
{
    my ($S) = @_;

    delete @{$S}{qw(a b)};

    return;
}
0425
# Forget the per-(i, j, k) index and pattern computed for operands a and b,
# keeping their mat and cnt entries.
sub delete_loop_temporaries
{
    my ($self) = @_;

    for my $field ('idx', 'pat')
    {
        delete $self->{$_}{$field} for 'a', 'b';
    }
}
0436
# Compute the storage index and pattern of element ($i1, $i2) of operand $x
# ('a' or 'b') into $S->{$x}{idx} and $S->{$x}{pat}.
#
# With no_size_check, an element outside the operand's bounds is treated as
# a known zero (pat '0', idx left unset) instead of letting idx() confess.
sub generate_index_and_pattern
{
    my ($S, $x, $i1, $i2) = @_;

    my $mat = $S->{$x}{mat};

    # The bounds test must use this call's coordinates.  The former version
    # passed undeclared globals $i / $k, so the test always saw undef and
    # the no_size_check branch never triggered.
    if ($S->{no_size_check} and not $mat->row_col_in_range($i1, $i2))
    {
        $S->{$x}{pat} = '0';
    }
    else
    {
        $S->{$x}{idx} = $mat->idx($i1, $i2);
        $S->{$x}{pat} = $mat->pattern($S->{$x}{idx});
    }
}
0451
# Prepare the a and b operands for the term a(i, k) * b(k, j) of c(i, j):
# clear the previous per-term state, then compute index and pattern for
# a at (i, k) and for b at (k, j), in that order.
sub generate_indices_and_patterns_for_multiplication
{
    my ($self, $i, $j, $k) = @_;

    $self->delete_loop_temporaries();

    my %coords_of = (a => [$i, $k], b => [$k, $j]);
    for my $operand ('a', 'b')
    {
        $self->generate_index_and_pattern($operand, @{ $coords_of{$operand} });
    }
}
0463
0464
0465
# Build the scalar-code term x * y for the current (i, j, k), where $x and $y
# name operand slots ('a' / 'b') filled by generate_index_and_pattern().
# Elements are written as NAME[idx*N+n].
# Returns undef when either factor is '0', "1" when both are '1', the other
# factor alone when one is '1', otherwise "X*Y".
sub generate_addend_standard
{
    my ($S, $x, $y) = @_;

    my ($xpat, $ypat) = ($S->{$x}{pat}, $S->{$y}{pat});

    return undef if $xpat eq '0' or $ypat eq '0';
    return "1"   if $xpat eq '1' and $ypat eq '1';

    # The matrix name is passed as a %s argument rather than interpolated
    # into the format, so a '%' in a name cannot corrupt the output.
    my $xstr = sprintf '%s[%2d*N+n]', $S->{$x}{mat}{name}, $S->{$x}{idx};
    my $ystr = sprintf '%s[%2d*N+n]', $S->{$y}{mat}{name}, $S->{$y}{idx};

    return $xstr if $ypat eq '1';
    return $ystr if $xpat eq '1';

    return "${xstr}*${ystr}";
}
0481
0482 sub multiply_standard
0483 {
0484
0485
0486
0487
0488
0489 check_multiply_arguments(@_);
0490
0491 my ($S, $a, $b, $c) = @_;
0492
0493 my $is_c_symmetric = $c->isa("GenMul::MatrixSym");
0494
0495
0496 my $k_max = $a->{N} <= $b->{M} ? $a->{N} : $b->{M};
0497
0498 for (my $i = 0; $i < $c->{M}; ++$i)
0499 {
0500 my $j_max = $is_c_symmetric ? $i + 1 : $c->{N};
0501
0502 for (my $j = 0; $j < $j_max; ++$j)
0503 {
0504 my $x = $c->idx($i, $j);
0505
0506 printf "$S->{prefix}$c->{name}\[%2d*N+n\] = ", $x;
0507
0508 my @sum;
0509
0510 for (my $k = 0; $k < $k_max; ++$k)
0511 {
0512 $S->generate_indices_and_patterns_for_multiplication($i, $j, $k);
0513
0514 my $addend = $S->generate_addend_standard('a', 'b');
0515
0516 push @sum, $addend if defined $addend;
0517 }
0518 if (@sum)
0519 {
0520 print join(" + ", @sum), ";";
0521 }
0522 else
0523 {
0524 print "0;"
0525 }
0526 print "\n";
0527 }
0528 }
0529
0530 $S->delete_temporaries();
0531 }
0532
0533
0534
# Like generate_addend_standard, but elements are written as
# NAME[idx*NAMEN+NAMEn], i.e. the stride and offset names are prefixed with
# each matrix's own name.
# Returns undef when either factor is '0', "1" when both are '1', the other
# factor alone when one is '1', otherwise "X*Y".
sub generate_addend_gpu
{
    my ($S, $x, $y) = @_;

    my ($xpat, $ypat) = ($S->{$x}{pat}, $S->{$y}{pat});

    return undef if $xpat eq '0' or $ypat eq '0';
    return "1"   if $xpat eq '1' and $ypat eq '1';

    # Names are passed as %s arguments rather than interpolated into the
    # format, so a '%' in a name cannot corrupt the output.
    my $xname = $S->{$x}{mat}{name};
    my $yname = $S->{$y}{mat}{name};
    my $xstr = sprintf '%s[%2d*%sN+%sn]', $xname, $S->{$x}{idx}, $xname, $xname;
    my $ystr = sprintf '%s[%2d*%sN+%sn]', $yname, $S->{$y}{idx}, $yname, $yname;

    return $xstr if $ypat eq '1';
    return $ystr if $xpat eq '1';

    return "${xstr}*${ystr}";
}
0550
# Print scalar code for c = a * b in the "gpu" flavour: the same as
# multiply_standard, except that every element is written as
# NAME[idx*NAMEN+NAMEn] (each matrix uses its own stride/offset names).
# For a symmetric c only elements with j <= i are written.
sub multiply_gpu
{
    # Validates the operands and records a / b in $S->{a}{mat} / $S->{b}{mat}.
    check_multiply_arguments(@_);

    my ($self, $ma, $mb, $mc) = @_;

    my $lower_only = $mc->isa("GenMul::MatrixSym");

    # Inner (summation) length: the smaller of a's columns and b's rows.
    my $inner = $mb->{M} < $ma->{N} ? $mb->{M} : $ma->{N};

    for (my $row = 0; $row < $mc->{M}; ++$row)
    {
        my $ncols = $lower_only ? $row + 1 : $mc->{N};

        for (my $col = 0; $col < $ncols; ++$col)
        {
            printf "$self->{prefix}$mc->{name}\[%2d*$mc->{name}N+$mc->{name}n\] = ", $mc->idx($row, $col);

            my @terms;
            for (my $k = 0; $k < $inner; ++$k)
            {
                $self->generate_indices_and_patterns_for_multiplication($row, $col, $k);

                my $term = $self->generate_addend_gpu('a', 'b');
                next unless defined $term;
                push @terms, $term;
            }

            if (not @terms)
            {
                print "0;";
            }
            else
            {
                print join(" + ", @terms), ";";
            }
            print "\n";
        }
    }

    $self->delete_temporaries();
}
0601
0602
0603
# Return the register name holding the current element of operand $x
# ('a' or 'b').  The first use of an element emits its
# "VECTYPE REG = LD(NAME, idx);" declaration and advances $S->{tick};
# every use increments the element's counter in $S->{$x}{cnt}.
sub load_if_needed
{
    my ($self, $which) = @_;

    my $idx = $self->{$which}{idx};
    my $mat = $self->{$which}{mat};
    my $reg = $mat->reg_name($idx);

    my $first_use = $self->{$which}{cnt}[$idx] == 0;
    if ($first_use)
    {
        $self->push_out("$self->{vectype} ${reg} = LD($mat->{name}, $idx);");
        $self->{tick} += 1;
    }

    $self->{$which}{cnt}[$idx] += 1;

    return $reg;
}
0622
# Emit "ST(NAME, idx, REG);" storing element $idx of $mat from its register,
# and return that register name.
sub store
{
    my ($self, $mat, $idx) = @_;

    my $register = $mat->reg_name($idx);
    my $stmt     = "ST($mat->{name}, $idx, $register);";

    $self->push_out($stmt);

    return $register;
}
0633
# Emit intrinsic (vector) code computing c = a * b.
#
# Statements are buffered in $S->{out}, then printed with $S->{prefix} in
# front of every non-empty line.  LD / ST / MUL / ADD / FMA are emitted as
# text; NOTE(review): their definitions must come from the code that
# includes this output, which is not visible here.
#
# Loop order is i, then k, then j, so each (i, k) pair produces one group of
# statements; groups are separated by a blank line.  Each operand element is
# loaded once (load_if_needed).  Terms with a '0' pattern are skipped, and
# '1' factors turn MUL / FMA into a plain copy / ADD.  $S->{tick} counts the
# emitted operations.  After c(i, j) has received its last term, its store
# is deferred until the tick has advanced by 4.  NOTE(review): presumably
# this lets the final FMA/ADD complete first -- confirm.  Elements that
# received no term are stored from all_zeros.  For a symmetric c only
# elements with j <= i are computed.
sub multiply_intrinsic
{
    # Validates the operands and records a / b in $S->{a}{mat} / $S->{b}{mat}.
    check_multiply_arguments(@_);

    my ($S, $a, $b, $c) = @_;

    $S->{tick} = 0;

    $S->{out} = [];

    # @cc: number of terms accumulated so far for each c element; after the
    # last k it is reused to hold the tick at which the store may be emitted.
    my (@cc, @to_store);
    @cc = (0) x $c->mat_size();

    # Per-element use counters consulted by load_if_needed().
    $S->{a}{cnt} = [ (0) x $a->mat_size() ];
    $S->{b}{cnt} = [ (0) x $b->mat_size() ];

    my $need_all_zeros = 0;
    my $need_all_ones = 0;

    my $is_c_symmetric = $c->isa("GenMul::MatrixSym");

    # Inner (summation) length: the smaller of a's columns and b's rows.
    my $k_max = $a->{N} <= $b->{M} ? $a->{N} : $b->{M};

    for (my $i = 0; $i < $c->{M}; ++$i)
    {
        my $j_max = $is_c_symmetric ? $i + 1 : $c->{N};

        for (my $k = 0; $k < $k_max; ++$k)
        {
            for (my $j = 0; $j < $j_max; ++$j)
            {
                my $x = $c->idx($i, $j);

                $S->generate_indices_and_patterns_for_multiplication($i, $j, $k);

                if ($S->{a}{pat} ne '0' and $S->{b}{pat} ne '0')
                {
                    # $sreg is set when the term reduces to a single register
                    # (one or both factors are '1'); otherwise a full product.
                    my ($areg, $breg, $sreg);

                    if ($S->{a}{pat} eq '1' and $S->{b}{pat} eq '1')
                    {
                        $need_all_ones = 1;
                        $sreg = "all_ones";
                    }
                    elsif ($S->{b}{pat} eq '1')
                    {
                        $sreg = $S->load_if_needed('a');
                    }
                    elsif ($S->{a}{pat} eq '1')
                    {
                        $sreg = $S->load_if_needed('b');
                    }
                    else
                    {
                        $areg = $S->load_if_needed('a');
                        $breg = $S->load_if_needed('b');
                    }

                    my $creg = $c->reg_name($x);

                    if ($cc[$x] == 0)
                    {
                        # First term declares and initialises c's register.
                        my $op = defined $sreg ? "${sreg}" : "MUL(${areg}, ${breg})";

                        $S->push_out("$S->{vectype} ${creg} = ", $op, ";");
                    }
                    else
                    {
                        # Later terms accumulate into it.
                        my $op = defined $sreg ?
                            "ADD(${sreg}, ${creg})" :
                            "FMA(${areg}, ${breg}, ${creg})";

                        $S->push_out("${creg} = ", $op, ";");
                    }

                    ++$cc[$x];
                    ++$S->{tick};
                }

                if ($k + 1 == $k_max)
                {
                    if ($cc[$x] == 0)
                    {
                        $need_all_zeros = 1;

                        $S->push_out("ST($c->{name}, $x, all_zeros);");
                    }
                    else
                    {
                        # Schedule the store 4 ticks from now.
                        $cc[$x] = $S->{tick} + 4;
                        push @to_store, $x;
                    }
                }

                # Emit every pending store whose scheduled tick has arrived.
                while (1)
                {
                    last unless @to_store;
                    my $s = $to_store[0];
                    last if $S->{tick} < $cc[$s];

                    $S->store($c, $s);
                    shift @to_store;
                    ++$S->{tick};
                }
            }

            # Blank line between (i, k) groups, but not after the last one.
            # Compare against the bounds actually iterated ($c->{M}, $k_max).
            # The former test used $a->{M} / $a->{N}, which only coincide
            # with them when the size check is enforced.
            $S->push_out("") unless $i + 1 == $c->{M} and $k + 1 == $k_max;
        }
    }

    # Flush stores still waiting at the end.
    for my $s (@to_store)
    {
        $S->store($c, $s);

        ++$S->{tick};
    }

    $S->handle_all_zeros_ones($need_all_zeros, $need_all_ones);

    for (@{$S->{out}})
    {
        print $S->{prefix} unless /^$/;
        print;
        print "\n";
    }

    $S->delete_temporaries();
}
0767
0768
0769
# Write c = a * b twice:
#   - as intrinsic code (multiply_intrinsic), inside a loop stepping n by
#     MPLEX_INTRINSICS_WIDTH_BYTES / sizeof(T);
#   - as plain scalar code (multiply_standard), inside a loop over every n.
# $fname '-' writes to the currently selected output handle; any other value
# names a file that is created or truncated.  Croaks if that file cannot be
# opened, or if closing it fails (e.g. a buffered write error).
sub dump_multiply_std_and_intrinsic
{
    my ($S, $fname, $a, $b, $c) = @_;

    # Three-argument open on a lexical handle.  The former two-argument
    # bareword open was unchecked, so a failed open silently discarded all
    # output.  The previously selected handle is restored afterwards,
    # instead of unconditionally selecting STDOUT.
    my ($fh, $prev_fh);
    unless ($fname eq '-')
    {
        open($fh, '>', $fname) or croak "Cannot open '$fname' for writing: $!";
        $prev_fh = select($fh);
    }

    print <<"FNORD";


for (int n = 0; n < N; n += MPLEX_INTRINSICS_WIDTH_BYTES / sizeof(T))
{
FNORD

    $S->multiply_intrinsic($a, $b, $c);

    print <<"FNORD";
}




for (int n = 0; n < N; ++n)
{
FNORD

    $S->multiply_standard($a, $b, $c);

    print <<"FNORD";
}

FNORD

    unless ($fname eq '-')
    {
        select($prev_fh);
        close($fh) or croak "Error writing '$fname': $!";
    }
}
0812
0813
0814
# Write c = a * b three times:
#   - as intrinsic code (multiply_intrinsic), inside a loop stepping n by
#     MPLEX_INTRINSICS_WIDTH_BYTES / sizeof(T);
#   - as plain scalar code (multiply_standard), inside a loop over every n;
#   - as "gpu"-flavoured scalar code (multiply_gpu), emitted without a loop.
# $fname '-' writes to the currently selected output handle; any other value
# names a file that is created or truncated.  Croaks if that file cannot be
# opened, or if closing it fails (e.g. a buffered write error).
sub dump_multiply_std_and_intrinsic_and_gpu
{
    my ($S, $fname, $a, $b, $c) = @_;

    # Three-argument open on a lexical handle.  The former two-argument
    # bareword open was unchecked, so a failed open silently discarded all
    # output.  The previously selected handle is restored afterwards,
    # instead of unconditionally selecting STDOUT.
    my ($fh, $prev_fh);
    unless ($fname eq '-')
    {
        open($fh, '>', $fname) or croak "Cannot open '$fname' for writing: $!";
        $prev_fh = select($fh);
    }

    print <<"FNORD";



for (int n = 0; n < N; n += MPLEX_INTRINSICS_WIDTH_BYTES / sizeof(T))
{
FNORD

    $S->multiply_intrinsic($a, $b, $c);

    print <<"FNORD";
}




for (int n = 0; n < N; ++n)
{
FNORD

    $S->multiply_standard($a, $b, $c);

    print <<"FNORD";
}


FNORD
    $S->multiply_gpu($a, $b, $c);
    print <<"FNORD";

FNORD

    unless ($fname eq '-')
    {
        select($prev_fh);
        close($fh) or croak "Error writing '$fname': $!";
    }
}
0863
0864
0865
0866
0867
0868
0869
0870 1;