77import java .io .IOException ;
88import java .io .ObjectInputStream ;
99import java .io .ObjectOutputStream ;
10+ import java .io .Serializable ;
1011import java .io .UnsupportedEncodingException ;
1112import java .math .BigInteger ;
1213import java .security .SecureRandom ;
@@ -554,8 +555,10 @@ public Signature sign(String M, String IDA, SM2KeyPair keyPair) {
554555 r = e .add (x1 );
555556 r = r .mod (n );
556557 } while (r .equals (BigInteger .ZERO ) || r .add (k ).equals (n ));
558+
557559 BigInteger s = ((keyPair .getPrivateKey ().add (BigInteger .ONE ).modInverse (n ))
558560 .multiply ((k .subtract (r .multiply (keyPair .getPrivateKey ()))).mod (n ))).mod (n );
561+
559562 return new Signature (r , s );
560563 }
561564
@@ -577,11 +580,14 @@ public boolean verify(String M, Signature signature, String IDA, ECPoint aPublic
577580 return false ;
578581 if (!between (signature .s , BigInteger .ONE , n ))
579582 return false ;
583+
580584 byte [] M_ = join (ZA (IDA , aPublicKey ), M .getBytes ());
581585 BigInteger e = new BigInteger (1 , sm3hash (M_ ));
582586 BigInteger t = signature .r .add (signature .s ).mod (n );
587+
583588 if (t .equals (BigInteger .ZERO ))
584589 return false ;
590+
585591 ECPoint p1 = G .multiply (signature .s ).normalize ();
586592 ECPoint p2 = aPublicKey .multiply (t ).normalize ();
587593 BigInteger x1 = p1 .add (p2 ).normalize ().getXCoord ().toBigInteger ();
@@ -620,6 +626,26 @@ private static byte[] KDF(byte[] Z, int klen) {
620626 return null ;
621627 }
622628
629+ /**
630+ * 传输实体类
631+ *
632+ * @author Potato
633+ *
634+ */
635+ private static class TransportEntity implements Serializable {
636+ final byte [] R ; //R点
637+ final byte [] S ; //验证S
638+ final byte [] Z ; //用户标识
639+ final byte [] K ; //公钥
640+
641+ public TransportEntity (byte [] r , byte [] s ,byte [] z ,ECPoint pKey ) {
642+ R = r ;
643+ S = s ;
644+ Z =z ;
645+ K =pKey .getEncoded (false );
646+ }
647+ }
648+
623649 /**
624650 * 密钥协商辅助类
625651 *
@@ -629,91 +655,138 @@ private static byte[] KDF(byte[] Z, int klen) {
629655 public static class KeyExchange {
630656 BigInteger rA ;
631657 ECPoint RA ;
658+ ECPoint V ;
659+ byte [] Z ;
632660 byte [] key ;
661+
662+ String ID ;
633663 SM2KeyPair keyPair ;
634664
635- public KeyExchange (SM2KeyPair keyPair ) {
665+ public KeyExchange (String ID ,SM2KeyPair keyPair ) {
666+ this .ID =ID ;
636667 this .keyPair = keyPair ;
668+ this .Z =ZA (ID , keyPair .getPublicKey ());
637669 }
638670
639671 /**
640672 * 密钥协商发起第一步
641673 *
642674 * @return
643675 */
644- public byte [] keyExchange_1 () {
676+ public TransportEntity keyExchange_1 () {
645677 rA = random (n );
646678 // rA=new BigInteger("83A2C9C8 B96E5AF7 0BD480B4 72409A9A 327257F1
647679 // EBB73F5B 073354B2 48668563".replace(" ", ""),16);
648680 RA = G .multiply (rA ).normalize ();
649- return RA .getEncoded (false );
681+ return new TransportEntity ( RA .getEncoded (false ), null , Z , keyPair . getPublicKey () );
650682 }
651683
652684 /**
653685 * 密钥协商响应方
654686 *
655- * @param IDA
656- * @param IDB
657- * @param aPublicKey
658- * @param RAbytes
687+ * @param entity 传输实体
659688 * @return
660689 */
661- public byte [] keyExchange_2 (String IDA , String IDB , ECPoint aPublicKey , byte [] RAbytes ) {
690+ public TransportEntity keyExchange_2 (TransportEntity entity ) {
662691 BigInteger rB = random (n );
663692 // BigInteger rB=new BigInteger("33FE2194 0342161C 55619C4A 0C060293
664693 // D543C80A F19748CE 176D8347 7DE71C80".replace(" ", ""),16);
665694 ECPoint RB = G .multiply (rB ).normalize ();
695+
696+ this .rA =rB ;
697+ this .RA =RB ;
666698
667699 BigInteger x2 = RB .getXCoord ().toBigInteger ();
668700 x2 = _2w .add (x2 .and (_2w .subtract (BigInteger .ONE )));
669701
670702 BigInteger tB = keyPair .getPrivateKey ().add (x2 .multiply (rB )).mod (n );
671- ECPoint RA = curve .decodePoint (RAbytes ).normalize ();
672-
703+ ECPoint RA = curve .decodePoint (entity . R ).normalize ();
704+
673705 BigInteger x1 = RA .getXCoord ().toBigInteger ();
674706 x1 = _2w .add (x1 .and (_2w .subtract (BigInteger .ONE )));
675707
708+ ECPoint aPublicKey =curve .decodePoint (entity .K ).normalize ();
676709 ECPoint temp = aPublicKey .add (RA .multiply (x1 ).normalize ()).normalize ();
677710 ECPoint V = temp .multiply (ecc_bc_spec .getH ().multiply (tB )).normalize ();
678711 if (V .isInfinity ())
679712 throw new IllegalStateException ();
680-
681- byte [] KB = KDF (join (V .getXCoord ().toBigInteger ().toByteArray (), V .getYCoord ().toBigInteger ().toByteArray (),
682- ZA (IDA , aPublicKey ), ZA (IDB , keyPair .getPublicKey ())), 16 );
713+ this .V =V ;
714+
715+ byte [] xV = V .getXCoord ().toBigInteger ().toByteArray ();
716+ byte [] yV = V .getYCoord ().toBigInteger ().toByteArray ();
717+ byte [] KB = KDF (join (xV , yV , entity .Z , this .Z ), 16 );
683718 key = KB ;
684719 System .out .print ("协商得B密钥:" );
685720 printHexString (KB );
686- return RB .getEncoded (false );
721+ byte [] sB = sm3hash (new byte [] { 0x02 }, yV ,
722+ sm3hash (xV , entity .Z , this .Z , RA .getXCoord ().toBigInteger ().toByteArray (),
723+ RA .getYCoord ().toBigInteger ().toByteArray (), RB .getXCoord ().toBigInteger ().toByteArray (),
724+ RB .getYCoord ().toBigInteger ().toByteArray ()));
725+ return new TransportEntity (RB .getEncoded (false ), sB ,this .Z ,keyPair .getPublicKey ());
687726 }
688727
689728 /**
690729 * 密钥协商发起方第二步
691730 *
692- * @param IDA
693- * @param IDB
694- * @param bPublicKey
695- * @param RBbytes
731+ * @param entity 传输实体
696732 */
697- public void keyExchange_3 (String IDA , String IDB , ECPoint bPublicKey , byte [] RBbytes ) {
733+ public TransportEntity keyExchange_3 (TransportEntity entity ) {
698734 BigInteger x1 = RA .getXCoord ().toBigInteger ();
699735 x1 = _2w .add (x1 .and (_2w .subtract (BigInteger .ONE )));
700-
736+
701737 BigInteger tA = keyPair .getPrivateKey ().add (x1 .multiply (rA )).mod (n );
702- ECPoint RB = curve .decodePoint (RBbytes ).normalize ();
738+ ECPoint RB = curve .decodePoint (entity . R ).normalize ();
703739
704740 BigInteger x2 = RB .getXCoord ().toBigInteger ();
705741 x2 = _2w .add (x2 .and (_2w .subtract (BigInteger .ONE )));
706-
742+
743+ ECPoint bPublicKey =curve .decodePoint (entity .K ).normalize ();
707744 ECPoint temp = bPublicKey .add (RB .multiply (x2 ).normalize ()).normalize ();
708745 ECPoint U = temp .multiply (ecc_bc_spec .getH ().multiply (tA )).normalize ();
709746 if (U .isInfinity ())
710747 throw new IllegalStateException ();
748+ this .V =U ;
711749
712- byte [] KA = KDF (join (U .getXCoord ().toBigInteger ().toByteArray (), U .getYCoord ().toBigInteger ().toByteArray (),
713- ZA (IDA , keyPair .getPublicKey ()), ZA (IDB , bPublicKey )), 16 );
750+ byte [] xU = U .getXCoord ().toBigInteger ().toByteArray ();
751+ byte [] yU = U .getYCoord ().toBigInteger ().toByteArray ();
752+ byte [] KA = KDF (join (xU , yU ,
753+ this .Z , entity .Z ), 16 );
714754 key = KA ;
715755 System .out .print ("协商得A密钥:" );
716756 printHexString (KA );
757+ byte [] s1 = sm3hash (new byte [] { 0x02 }, yU ,
758+ sm3hash (xU , this .Z , entity .Z , RA .getXCoord ().toBigInteger ().toByteArray (),
759+ RA .getYCoord ().toBigInteger ().toByteArray (), RB .getXCoord ().toBigInteger ().toByteArray (),
760+ RB .getYCoord ().toBigInteger ().toByteArray ()));
761+ if (Arrays .equals (entity .S , s1 ))
762+ System .out .println ("B->A 密钥确认成功" );
763+ else
764+ System .out .println ("B->A 密钥确认失败" );
765+ byte [] sA = sm3hash (new byte [] { 0x03 }, yU ,
766+ sm3hash (xU , this .Z , entity .Z , RA .getXCoord ().toBigInteger ().toByteArray (),
767+ RA .getYCoord ().toBigInteger ().toByteArray (), RB .getXCoord ().toBigInteger ().toByteArray (),
768+ RB .getYCoord ().toBigInteger ().toByteArray ()));
769+
770+ return new TransportEntity (RA .getEncoded (false ), sA ,this .Z ,keyPair .getPublicKey ());
771+ }
772+
773+ /**
774+ * 密钥确认最后一步
775+ *
776+ * @param entity 传输实体
777+ */
778+ public void keyExchange_4 (TransportEntity entity ) {
779+ byte [] xV = V .getXCoord ().toBigInteger ().toByteArray ();
780+ byte [] yV = V .getYCoord ().toBigInteger ().toByteArray ();
781+ ECPoint RA = curve .decodePoint (entity .R ).normalize ();
782+ byte [] s2 = sm3hash (new byte [] { 0x03 }, yV ,
783+ sm3hash (xV , entity .Z , this .Z , RA .getXCoord ().toBigInteger ().toByteArray (),
784+ RA .getYCoord ().toBigInteger ().toByteArray (), this .RA .getXCoord ().toBigInteger ().toByteArray (),
785+ this .RA .getYCoord ().toBigInteger ().toByteArray ()));
786+ if (Arrays .equals (entity .S , s2 ))
787+ System .out .println ("A->B 密钥确认成功" );
788+ else
789+ System .out .println ("A->B 密钥确认失败" );
717790 }
718791 }
719792
@@ -757,14 +830,15 @@ public static void main(String[] args) throws UnsupportedEncodingException {
757830 System .out .println ("-----------------密钥协商-----------------" );
758831 String aID = "AAAAAAAAAAAAA" ;
759832 SM2KeyPair aKeyPair = sm02 .generateKeyPair ();
760- KeyExchange aKeyExchange = new KeyExchange (aKeyPair );
833+ KeyExchange aKeyExchange = new KeyExchange (aID , aKeyPair );
761834
762835 String bID = "BBBBBBBBBBBBB" ;
763836 SM2KeyPair bKeyPair = sm02 .generateKeyPair ();
764- KeyExchange bKeyExchange = new KeyExchange (bKeyPair );
765- byte [] RAbytes = aKeyExchange .keyExchange_1 ();
766- byte [] RBbytes = bKeyExchange .keyExchange_2 (aID , bID , aKeyPair .getPublicKey (), RAbytes );
767- aKeyExchange .keyExchange_3 (aID , bID , bKeyPair .getPublicKey (), RBbytes );
837+ KeyExchange bKeyExchange = new KeyExchange (bID ,bKeyPair );
838+ TransportEntity entity1 = aKeyExchange .keyExchange_1 ();
839+ TransportEntity entity2 = bKeyExchange .keyExchange_2 (entity1 );
840+ TransportEntity entity3 = aKeyExchange .keyExchange_3 (entity2 );
841+ bKeyExchange .keyExchange_4 (entity3 );
768842 }
769843
770844 public static class Signature {
0 commit comments